mirror of
https://github.com/davidallendj/opaal.git
synced 2025-12-20 03:27:02 -07:00
Major refactoring and code restructure
This commit is contained in:
parent
72adbe1f0d
commit
6d63211d35
10 changed files with 454 additions and 859 deletions
|
|
@ -1,424 +0,0 @@
|
||||||
package opaal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"davidallendj/opaal/internal/oidc"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"reflect"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/davidallendj/go-utils/util"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
|
||||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
|
||||||
"github.com/lestrrat-go/jwx/v2/jws"
|
|
||||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TODO: change authorization code flow to use these instead
|
|
||||||
type AuthorizationCodeFlowEndpoints struct {
|
|
||||||
Login string
|
|
||||||
Token string
|
|
||||||
Identities string
|
|
||||||
TrustedIssuer string
|
|
||||||
Register string
|
|
||||||
}
|
|
||||||
|
|
||||||
func AuthorizationCodeWithConfig(config *Config, server *Server, client *Client, idp *oidc.IdentityProvider) error {
|
|
||||||
// check preconditions are met
|
|
||||||
err := verifyParams(config, server, client, idp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// build the authorization URL to redirect user for social sign-in
|
|
||||||
state := config.Authentication.Flows["authorization_code"]["state"]
|
|
||||||
var authorizationUrl = client.BuildAuthorizationUrl(idp.Endpoints.Authorization, state)
|
|
||||||
|
|
||||||
// print the authorization URL for sharing
|
|
||||||
fmt.Printf("Login with identity provider:\n\n %s/login\n %s\n\n",
|
|
||||||
server.GetListenAddr(), authorizationUrl,
|
|
||||||
)
|
|
||||||
|
|
||||||
// automatically open browser to initiate login flow (only useful for testing and debugging)
|
|
||||||
if config.Options.OpenBrowser {
|
|
||||||
util.OpenUrl(authorizationUrl)
|
|
||||||
}
|
|
||||||
|
|
||||||
// authorize oauth client and listen for callback from provider
|
|
||||||
fmt.Printf("Waiting for authorization code redirect @%s/oidc/callback...\n", server.GetListenAddr())
|
|
||||||
code, err := server.WaitForAuthorizationCode(authorizationUrl, "")
|
|
||||||
if errors.Is(err, http.ErrServerClosed) {
|
|
||||||
fmt.Printf("\n=========================================\nServer closed.\n=========================================\n\n")
|
|
||||||
} else if err != nil {
|
|
||||||
return fmt.Errorf("failed to start server: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// start up another server in background to listen for success or failures
|
|
||||||
d := StartListener(server)
|
|
||||||
|
|
||||||
// use code from response and exchange for bearer token (with ID token)
|
|
||||||
bearerToken, err := client.FetchTokenFromAuthenticationServer(
|
|
||||||
code,
|
|
||||||
idp.Endpoints.Token,
|
|
||||||
state,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to fetch token from issuer: %v", err)
|
|
||||||
}
|
|
||||||
// fmt.Printf("%v\n", string(bearerToken))
|
|
||||||
|
|
||||||
// unmarshal data to get id_token and access_token
|
|
||||||
var data map[string]any
|
|
||||||
err = json.Unmarshal([]byte(bearerToken), &data)
|
|
||||||
if err != nil || data == nil {
|
|
||||||
return fmt.Errorf("failed to unmarshal token: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// make sure we have an ID token
|
|
||||||
if data["id_token"] == nil {
|
|
||||||
return fmt.Errorf("no ID token found...aborting")
|
|
||||||
}
|
|
||||||
|
|
||||||
// extract ID token from bearer as JSON string for easy consumption
|
|
||||||
idToken := data["id_token"].(string)
|
|
||||||
idJwtSegments, err := util.DecodeJwt(idToken)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("failed to parse ID token: %v\n", err)
|
|
||||||
} else {
|
|
||||||
fmt.Printf("id_token: %v\n", idToken)
|
|
||||||
if config.Options.DecodeIdToken {
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("failed to decode JWT: %v\n", err)
|
|
||||||
} else {
|
|
||||||
for i, segment := range idJwtSegments {
|
|
||||||
// don't print last segment (signatures)
|
|
||||||
if i == len(idJwtSegments)-1 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
fmt.Printf("%s\n", string(segment))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fmt.Println()
|
|
||||||
}
|
|
||||||
|
|
||||||
// extract the access token to get the scopes
|
|
||||||
accessToken := data["access_token"].(string)
|
|
||||||
accessJwtSegments, err := util.DecodeJwt(accessToken)
|
|
||||||
if err != nil || len(accessJwtSegments) <= 0 {
|
|
||||||
fmt.Printf("failed to parse access token: %v\n", err)
|
|
||||||
} else {
|
|
||||||
fmt.Printf("access_token (from identity provider): %v\n", accessToken)
|
|
||||||
if config.Options.DecodeIdToken {
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("failed to decode JWT: %v\n", err)
|
|
||||||
} else {
|
|
||||||
for i, segment := range accessJwtSegments {
|
|
||||||
// don't print last segment (signatures)
|
|
||||||
if i == len(accessJwtSegments)-1 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
fmt.Printf("%s\n", string(segment))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fmt.Println()
|
|
||||||
}
|
|
||||||
|
|
||||||
if !config.Options.ForwardToken {
|
|
||||||
|
|
||||||
// TODO: implement our own JWT to send to Hydra
|
|
||||||
// 1. verify that the JWT from the issuer is valid
|
|
||||||
key, ok := idp.Jwks.Key(0)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("no key found in key set")
|
|
||||||
}
|
|
||||||
|
|
||||||
parsedIdToken, err := jwt.ParseString(idToken, jwt.WithKey(jwa.RS256, key))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse ID token: %v", err)
|
|
||||||
}
|
|
||||||
_, err = jwt.ParseString(accessToken, jwt.WithKeySet(idp.Jwks))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse access token: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = jws.Verify([]byte(idToken), jws.WithKeySet(idp.Jwks), jws.WithValidateKey(true))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to verify JWT: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. create a new JWKS (or just JWK) to be verified
|
|
||||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to generate private RSA k-ey: %v", err)
|
|
||||||
}
|
|
||||||
privateJwk, err := jwk.FromRaw(privateKey)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create private JWK: %v", err)
|
|
||||||
}
|
|
||||||
publicJwk, err := jwk.PublicKeyOf(privateJwk)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create public JWK: %v", err)
|
|
||||||
}
|
|
||||||
publicJwk.Set("kid", uuid.New().String())
|
|
||||||
|
|
||||||
// 3. add opaal's server host as a trusted issuer with JWK
|
|
||||||
fmt.Printf("Attempting to add issuer to authorization server...\n")
|
|
||||||
res, err := client.AddTrustedIssuer(
|
|
||||||
config.Authorization.RequestUrls.TrustedIssuers,
|
|
||||||
server.Addr,
|
|
||||||
publicJwk,
|
|
||||||
"1",
|
|
||||||
time.Second*3600,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to add trusted issuer: %v", err)
|
|
||||||
}
|
|
||||||
fmt.Printf("%v\n", string(res))
|
|
||||||
|
|
||||||
// 4. create a new JWT based on the claims from the identity provider and sign
|
|
||||||
payload := parsedIdToken.PrivateClaims()
|
|
||||||
payload["iss"] = server.Addr
|
|
||||||
payload["aud"] = []string{config.Authorization.RequestUrls.Token}
|
|
||||||
payload["iat"] = time.Now().Unix()
|
|
||||||
payload["nbf"] = time.Now().Unix()
|
|
||||||
payload["exp"] = time.Now().Add(time.Second * 3600).Unix()
|
|
||||||
payload["sub"] = "1"
|
|
||||||
payloadJson, err := json.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to marshal payload: %v", err)
|
|
||||||
}
|
|
||||||
newToken, err := jws.Sign(payloadJson, jws.WithJSON(), jws.WithKey(jwa.RS256, privateJwk))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to sign token: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// sig = rsasha256(b64urlencode(header) + "." + b64urlencode(payload))
|
|
||||||
// signature := util.EncodeBase64() + util.EncodeBase64() +
|
|
||||||
|
|
||||||
// 5. dynamically register new OAuth client and authorize it
|
|
||||||
fmt.Printf("Registering new OAuth2 client with authorization server...\n")
|
|
||||||
res, err = client.RegisterOAuthClient(config.Authorization.RequestUrls.Register, []string{})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to register client: %v", err)
|
|
||||||
}
|
|
||||||
fmt.Printf("%v\n", string(res))
|
|
||||||
|
|
||||||
// extract the client info from response
|
|
||||||
var clientData map[string]any
|
|
||||||
err = json.Unmarshal(res, &clientData)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to unmarshal client data: %v", err)
|
|
||||||
} else {
|
|
||||||
// check for error first
|
|
||||||
errJson := clientData["error"]
|
|
||||||
if errJson == nil {
|
|
||||||
client.Id = clientData["client_id"].(string)
|
|
||||||
client.Secret = clientData["client_secret"].(string)
|
|
||||||
} else {
|
|
||||||
// delete client and create again
|
|
||||||
fmt.Printf("Attempting to delete client...\n")
|
|
||||||
err := client.DeleteOAuthClient(config.Authorization.RequestUrls.Clients)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to delete OAuth client: %v", err)
|
|
||||||
}
|
|
||||||
fmt.Printf("Attempting to re-create client...\n")
|
|
||||||
res, err := client.CreateOAuthClient(config.Authorization.RequestUrls.Clients, []string{})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to register client: %v", err)
|
|
||||||
}
|
|
||||||
fmt.Printf("%v\n", string(res))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// authorize the client
|
|
||||||
// fmt.Printf("Attempting to authorize client...\n")
|
|
||||||
// res, err = client.AuthorizeOAuthClient(config.Authorization.RequestUrls.Authorize)
|
|
||||||
// if err != nil {
|
|
||||||
// return fmt.Errorf("failed to authorize client: %v", err)
|
|
||||||
// }
|
|
||||||
// fmt.Printf("%v\n", string(res))
|
|
||||||
|
|
||||||
// 6. send JWT to authorization server and receive a access token
|
|
||||||
if config.Authorization.RequestUrls.Token != "" {
|
|
||||||
fmt.Printf("Fetching access token from authorization server...\n")
|
|
||||||
res, err := client.PerformTokenGrant(config.Authorization.RequestUrls.Token, string(newToken))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to fetch access token: %v", err)
|
|
||||||
}
|
|
||||||
fmt.Printf("%s\n", res)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// extract the scope from access token claims
|
|
||||||
// var scope []string
|
|
||||||
// var accessJsonPayload map[string]any
|
|
||||||
// var accessJwtPayload []byte = accessJwtSegments[1]
|
|
||||||
// if accessJsonPayload != nil {
|
|
||||||
// err := json.Unmarshal(accessJwtPayload, &accessJsonPayload)
|
|
||||||
// if err != nil {
|
|
||||||
// return fmt.Errorf("failed to unmarshal JWT: %v", err)
|
|
||||||
// }
|
|
||||||
// scope = idJsonPayload["scope"].([]string)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// create a new identity with identity and session manager if url is provided
|
|
||||||
// if config.RequestUrls.Identities != "" {
|
|
||||||
// fmt.Printf("Attempting to create a new identity...\n")
|
|
||||||
// err := client.CreateIdentity(config.RequestUrls.Identities, idToken)
|
|
||||||
// if err != nil {
|
|
||||||
// return fmt.Errorf("failed to create new identity: %v", err)
|
|
||||||
// }
|
|
||||||
// _, err = client.FetchIdentities(config.RequestUrls.Identities)
|
|
||||||
// if err != nil {
|
|
||||||
// return fmt.Errorf("failed to fetch identities: %v", err)
|
|
||||||
// }
|
|
||||||
// fmt.Printf("Created new identity successfully.\n\n")
|
|
||||||
// }
|
|
||||||
|
|
||||||
// extract the subject from ID token claims
|
|
||||||
var subject string
|
|
||||||
var audience []string
|
|
||||||
var idJsonPayload map[string]any
|
|
||||||
var idJwtPayload []byte = idJwtSegments[1]
|
|
||||||
if idJwtPayload != nil {
|
|
||||||
err := json.Unmarshal(idJwtPayload, &idJsonPayload)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to unmarshal JWT: %v", err)
|
|
||||||
}
|
|
||||||
subject = idJsonPayload["sub"].(string)
|
|
||||||
audType := reflect.ValueOf(idJsonPayload["aud"])
|
|
||||||
switch audType.Kind() {
|
|
||||||
case reflect.String:
|
|
||||||
audience = append(audience, idJsonPayload["aud"].(string))
|
|
||||||
case reflect.Array:
|
|
||||||
audience = idJsonPayload["aud"].([]string)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("failed to extract subject from ID token claims")
|
|
||||||
}
|
|
||||||
|
|
||||||
// fetch JWKS and add issuer to authentication server to submit ID token
|
|
||||||
fmt.Printf("Fetching JWKS from authentication server for verification...\n")
|
|
||||||
err = idp.FetchJwks()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to fetch JWK: %v", err)
|
|
||||||
} else {
|
|
||||||
fmt.Printf("Successfully retrieved JWK from authentication server.\n\n")
|
|
||||||
fmt.Printf("Attempting to add issuer to authorization server...\n")
|
|
||||||
res, err := client.AddTrustedIssuerWithIdentityProvider(
|
|
||||||
config.Authorization.RequestUrls.TrustedIssuers,
|
|
||||||
idp,
|
|
||||||
subject,
|
|
||||||
time.Duration(1000),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to add trusted issuer: %v", err)
|
|
||||||
}
|
|
||||||
fmt.Printf("%v\n", string(res))
|
|
||||||
}
|
|
||||||
|
|
||||||
// add client ID to audience
|
|
||||||
audience = append(audience, client.Id)
|
|
||||||
audience = append(audience, "http://127.0.0.1:4444/oauth2/token")
|
|
||||||
|
|
||||||
// try and register a new client with authorization server
|
|
||||||
fmt.Printf("Registering new OAuth2 client with authorization server...\n")
|
|
||||||
res, err := client.RegisterOAuthClient(config.Authorization.RequestUrls.Register, audience)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to register client: %v", err)
|
|
||||||
}
|
|
||||||
fmt.Printf("%v\n", string(res))
|
|
||||||
|
|
||||||
// extract the client info from response
|
|
||||||
var clientData map[string]any
|
|
||||||
err = json.Unmarshal(res, &clientData)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to unmarshal client data: %v", err)
|
|
||||||
} else {
|
|
||||||
// check for error first
|
|
||||||
errJson := clientData["error"]
|
|
||||||
if errJson == nil {
|
|
||||||
client.Id = clientData["client_id"].(string)
|
|
||||||
client.Secret = clientData["client_secret"].(string)
|
|
||||||
} else {
|
|
||||||
// delete client and create again
|
|
||||||
fmt.Printf("Attempting to delete client...\n")
|
|
||||||
err := client.DeleteOAuthClient(config.Authorization.RequestUrls.Clients)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to delete OAuth client: %v", err)
|
|
||||||
}
|
|
||||||
fmt.Printf("Attempting to re-create client...\n")
|
|
||||||
res, err := client.CreateOAuthClient(config.Authorization.RequestUrls.Clients, audience)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to register client: %v", err)
|
|
||||||
}
|
|
||||||
fmt.Printf("%v\n", string(res))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// authorize the client
|
|
||||||
// fmt.Printf("Attempting to authorize client...\n")
|
|
||||||
// res, err = client.AuthorizeOAuthClient(config.Authorization.RequestUrls.Authorize)
|
|
||||||
// if err != nil {
|
|
||||||
// return fmt.Errorf("failed to authorize client: %v", err)
|
|
||||||
// }
|
|
||||||
// fmt.Printf("%v\n", string(res))
|
|
||||||
|
|
||||||
// use ID token/user info to fetch access token from authentication server
|
|
||||||
if config.Authorization.RequestUrls.Token != "" {
|
|
||||||
fmt.Printf("Fetching access token from authorization server...\n")
|
|
||||||
res, err := client.PerformTokenGrant(config.Authorization.RequestUrls.Token, idToken)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to fetch access token: %v", err)
|
|
||||||
}
|
|
||||||
fmt.Printf("%s\n", res)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var access_token []byte
|
|
||||||
d <- access_token
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func verifyParams(config *Config, server *Server, client *Client, idp *oidc.IdentityProvider) error {
|
|
||||||
// make sure we have a valid server and client
|
|
||||||
if server == nil {
|
|
||||||
return fmt.Errorf("server not initialized or valid (server == nil)")
|
|
||||||
}
|
|
||||||
if client == nil {
|
|
||||||
return fmt.Errorf("client not initialized or valid (client == nil)")
|
|
||||||
}
|
|
||||||
if idp == nil {
|
|
||||||
return fmt.Errorf("identity provider not initialized or valid (idp == nil)")
|
|
||||||
}
|
|
||||||
// check if all appropriate parameters are set in config
|
|
||||||
if !HasRequiredConfigParams(config) {
|
|
||||||
return fmt.Errorf("required params not set correctly or missing")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func StartListener(server *Server) chan []byte {
|
|
||||||
d := make(chan []byte)
|
|
||||||
quit := make(chan bool)
|
|
||||||
|
|
||||||
go server.Serve(d)
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-d:
|
|
||||||
fmt.Printf("got access token")
|
|
||||||
quit <- true
|
|
||||||
case <-quit:
|
|
||||||
close(d)
|
|
||||||
close(quit)
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
|
|
@ -1,93 +0,0 @@
|
||||||
package opaal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"net/http/cookiejar"
|
|
||||||
"slices"
|
|
||||||
|
|
||||||
"github.com/davidallendj/go-utils/mathx"
|
|
||||||
"golang.org/x/net/publicsuffix"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Client struct {
|
|
||||||
http.Client
|
|
||||||
Id string `yaml:"id"`
|
|
||||||
Secret string `yaml:"secret"`
|
|
||||||
Name string `yaml:"name"`
|
|
||||||
Description string `yaml:"description"`
|
|
||||||
Issuer string `yaml:"issuer"`
|
|
||||||
RegistrationAccessToken string `yaml:"registration-access-token"`
|
|
||||||
RedirectUris []string `yaml:"redirect-uris"`
|
|
||||||
Scope []string `yaml:"scope"`
|
|
||||||
FlowId string
|
|
||||||
CsrfToken string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewClient() *Client {
|
|
||||||
return &Client{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewClientWithConfig(config *Config) *Client {
|
|
||||||
// make sure config is valid
|
|
||||||
if config == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// make sure we have at least one client
|
|
||||||
clients := config.Authentication.Clients
|
|
||||||
if len(clients) <= 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// use the first client found by default
|
|
||||||
return &Client{
|
|
||||||
Id: clients[0].Id,
|
|
||||||
Secret: clients[0].Secret,
|
|
||||||
Name: clients[0].Name,
|
|
||||||
Issuer: clients[0].Issuer,
|
|
||||||
Scope: clients[0].Scope,
|
|
||||||
RedirectUris: clients[0].RedirectUris,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewClientWithConfigByIndex(config *Config, index int) *Client {
|
|
||||||
size := len(config.Authentication.Clients)
|
|
||||||
index = mathx.Clamp(index, 0, size)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewClientWithConfigByName(config *Config, name string) *Client {
|
|
||||||
index := slices.IndexFunc(config.Authentication.Clients, func(c Client) bool {
|
|
||||||
return c.Name == name
|
|
||||||
})
|
|
||||||
if index >= 0 {
|
|
||||||
return &config.Authentication.Clients[index]
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewClientWithConfigByProvider(config *Config, issuer string) *Client {
|
|
||||||
index := slices.IndexFunc(config.Authentication.Clients, func(c Client) bool {
|
|
||||||
return c.Issuer == issuer
|
|
||||||
})
|
|
||||||
|
|
||||||
if index >= 0 {
|
|
||||||
return &config.Authentication.Clients[index]
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewClientWithConfigById(config *Config, id string) *Client {
|
|
||||||
index := slices.IndexFunc(config.Authentication.Clients, func(c Client) bool {
|
|
||||||
return c.Id == id
|
|
||||||
})
|
|
||||||
if index >= 0 {
|
|
||||||
return &config.Authentication.Clients[index]
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (client *Client) ClearCookies() {
|
|
||||||
jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
|
|
||||||
client.Jar = jar
|
|
||||||
}
|
|
||||||
|
|
@ -1,156 +0,0 @@
|
||||||
package db
|
|
||||||
|
|
||||||
import (
|
|
||||||
"davidallendj/opaal/internal/oidc"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/jmoiron/sqlx"
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
|
||||||
)
|
|
||||||
|
|
||||||
func CreateIdentityProvidersIfNotExists(path string) (*sqlx.DB, error) {
|
|
||||||
schema := `
|
|
||||||
CREATE TABLE IF NOT EXISTS identity_providers (
|
|
||||||
issuer TEXT NOT NULL,
|
|
||||||
authorization_endpoint TEXT,
|
|
||||||
token_endpoint TEXT,
|
|
||||||
revocation_endpoint TEXT,
|
|
||||||
introspection_endpoint TEXT,
|
|
||||||
userinfo_endpoint TEXT,
|
|
||||||
jwks_uri TEXT,
|
|
||||||
response_types_supported TEXT,
|
|
||||||
response_modes_supported TEXT,
|
|
||||||
grant_types_supported TEXT,
|
|
||||||
token_endpoint_auth_methods_supported TEXT,
|
|
||||||
subject_types_supported TEXT,
|
|
||||||
id_token_signing_alg_values_supported TEXT,
|
|
||||||
claim_types_supported TEXT,
|
|
||||||
claims_supported TEXT,
|
|
||||||
jwks TEXT,
|
|
||||||
|
|
||||||
PRIMARY KEY (issuer)
|
|
||||||
);
|
|
||||||
`
|
|
||||||
db, err := sqlx.Open("sqlite3", path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("could not open database: %v", err)
|
|
||||||
}
|
|
||||||
db.MustExec(schema)
|
|
||||||
return db, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func InsertIdentityProviders(path string, providers *[]oidc.IdentityProvider) error {
|
|
||||||
if providers == nil {
|
|
||||||
return fmt.Errorf("states == nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
// create database if it doesn't already exist
|
|
||||||
db, err := CreateIdentityProvidersIfNotExists(path)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// insert all probe states into db
|
|
||||||
tx := db.MustBegin()
|
|
||||||
for _, state := range *providers {
|
|
||||||
sql := `INSERT OR REPLACE INTO identity_providers
|
|
||||||
(
|
|
||||||
issuer,
|
|
||||||
authorization_endpoint,
|
|
||||||
token_endpoint,
|
|
||||||
revocation_endpoint,
|
|
||||||
introspection_endpoint,
|
|
||||||
userinfo_endpoint,
|
|
||||||
jwks_uri,
|
|
||||||
response_types_supported,
|
|
||||||
response_modes_supported,
|
|
||||||
grant_types_supported,
|
|
||||||
token_endpoint_auth_methods_supported,
|
|
||||||
subject_types_supported,
|
|
||||||
id_token_signing_alg_values_supported,
|
|
||||||
claim_types_supported,
|
|
||||||
claims_supported,
|
|
||||||
jwks
|
|
||||||
)
|
|
||||||
VALUES
|
|
||||||
(
|
|
||||||
:issuer,
|
|
||||||
:authorization_endpoint,
|
|
||||||
:token_endpoint,
|
|
||||||
:revocation_endpoint,
|
|
||||||
:introspection_endpoint,
|
|
||||||
:userinfo_endpoint,
|
|
||||||
:jwks_uri,
|
|
||||||
:response_types_supported,
|
|
||||||
:response_modes_supported,
|
|
||||||
:grant_types_supported,
|
|
||||||
:token_endpoint_auth_methods_supported,
|
|
||||||
:subject_types_supported,
|
|
||||||
:id_token_signing_alg_values_supported,
|
|
||||||
:claim_types_supported,
|
|
||||||
:claims_supported,
|
|
||||||
:jwks
|
|
||||||
);`
|
|
||||||
_, err := tx.NamedExec(sql, &state)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("could not execute transaction: %v\n", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = tx.Commit()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("could not commit transaction: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetIdentityProvider(path string, issuer string) (*oidc.IdentityProvider, error) {
|
|
||||||
db, err := sqlx.Open("sqlite3", path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("could not open database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
results := &oidc.IdentityProvider{}
|
|
||||||
err = db.Select(&results, "SELECT * FROM magellan_scanned_ports ORDER BY host ASC, port ASC LIMIT 1;")
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("could not retrieve probes: %v", err)
|
|
||||||
}
|
|
||||||
return results, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetIdentityProviders(path string) ([]oidc.IdentityProvider, error) {
|
|
||||||
db, err := sqlx.Open("sqlite3", path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("could not open database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
results := []oidc.IdentityProvider{}
|
|
||||||
err = db.Select(&results, "SELECT * FROM magellan_scanned_ports ORDER BY host ASC, port ASC;")
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("could not retrieve probes: %v", err)
|
|
||||||
}
|
|
||||||
return results, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DeleteIdentityProviders(path string, results *[]oidc.IdentityProvider) error {
|
|
||||||
if results == nil {
|
|
||||||
return fmt.Errorf("no probe results found")
|
|
||||||
}
|
|
||||||
db, err := sqlx.Open("sqlite3", path)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("could not open database: %v", err)
|
|
||||||
}
|
|
||||||
tx := db.MustBegin()
|
|
||||||
for _, state := range *results {
|
|
||||||
sql := `DELETE FROM identity_providers WHERE host = :issuer;`
|
|
||||||
_, err := tx.NamedExec(sql, &state)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("could not execute transaction: %v\n", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = tx.Commit()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("could not commit transaction: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package opaal
|
package flows
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"davidallendj/opaal/internal/oauth"
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -15,9 +16,9 @@ type ClientCredentialsFlowEndpoints struct {
|
||||||
Token string
|
Token string
|
||||||
}
|
}
|
||||||
|
|
||||||
func ClientCredentials(eps ClientCredentialsFlowEndpoints, client *Client) error {
|
func NewClientCredentialsFlow(eps ClientCredentialsFlowEndpoints, client *oauth.Client) error {
|
||||||
// register a new OAuth 2 client with authorization srever
|
// register a new OAuth 2 client with authorization srever
|
||||||
_, err := client.CreateOAuthClient(eps.Create, nil)
|
_, err := client.CreateOAuthClient(eps.Create)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to register OAuth client: %v", err)
|
return fmt.Errorf("failed to register OAuth client: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -37,12 +38,3 @@ func ClientCredentials(eps ClientCredentialsFlowEndpoints, client *Client) error
|
||||||
fmt.Printf("token: %v\n", string(res))
|
fmt.Printf("token: %v\n", string(res))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ClientCredentialsWithConfig(config *Config, client *Client) error {
|
|
||||||
eps := ClientCredentialsFlowEndpoints{
|
|
||||||
Create: config.Authorization.RequestUrls.Clients,
|
|
||||||
Authorize: config.Authorization.RequestUrls.Authorize,
|
|
||||||
Token: config.Authorization.RequestUrls.Token,
|
|
||||||
}
|
|
||||||
return ClientCredentials(eps, client)
|
|
||||||
}
|
|
||||||
319
internal/flows/jwt_bearer.go
Normal file
319
internal/flows/jwt_bearer.go
Normal file
|
|
@ -0,0 +1,319 @@
|
||||||
|
package flows
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"davidallendj/opaal/internal/oauth"
|
||||||
|
"davidallendj/opaal/internal/oidc"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/davidallendj/go-utils/cryptox"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||||
|
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||||
|
"github.com/lestrrat-go/jwx/v2/jws"
|
||||||
|
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type JwtBearerFlowParams struct {
|
||||||
|
AccessToken string
|
||||||
|
IdToken string
|
||||||
|
IdentityProvider *oidc.IdentityProvider
|
||||||
|
TrustedIssuer *oauth.TrustedIssuer
|
||||||
|
Client *oauth.Client
|
||||||
|
Verbose bool
|
||||||
|
KeyPath string
|
||||||
|
}
|
||||||
|
|
||||||
|
type JwtBearerEndpoints struct {
|
||||||
|
TrustedIssuers string
|
||||||
|
Token string
|
||||||
|
Clients string
|
||||||
|
Register string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewJwtBearerFlow(eps JwtBearerEndpoints, params JwtBearerFlowParams) (string, error) {
|
||||||
|
// 1. verify that the JWT from the issuer is valid using all keys
|
||||||
|
var (
|
||||||
|
idp = params.IdentityProvider
|
||||||
|
accessToken = params.AccessToken
|
||||||
|
idToken = params.IdToken
|
||||||
|
client = params.Client
|
||||||
|
trustedIssuer = params.TrustedIssuer
|
||||||
|
verbose = params.Verbose
|
||||||
|
)
|
||||||
|
if accessToken != "" {
|
||||||
|
_, err := jws.Verify([]byte(accessToken), jws.WithKeySet(idp.KeySet), jws.WithValidateKey(true))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to verify access token: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if idToken != "" {
|
||||||
|
_, err := jws.Verify([]byte(idToken), jws.WithKeySet(idp.KeySet), jws.WithValidateKey(true))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to verify ID token: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Check if we are already registered as a trusted issuer with authorization server...
|
||||||
|
|
||||||
|
// 3.a if not, create a new JWKS (or just JWK) to be verified
|
||||||
|
var (
|
||||||
|
keyPath string = params.KeyPath
|
||||||
|
privateJwk jwk.Key
|
||||||
|
publicJwk jwk.Key
|
||||||
|
)
|
||||||
|
rawPrivateKey, err := os.ReadFile(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
if verbose {
|
||||||
|
fmt.Printf("failed to read private key...generating a new one.\n")
|
||||||
|
}
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to generate new RSA key: %v", err)
|
||||||
|
}
|
||||||
|
privateJwk, publicJwk, err = cryptox.GenerateJwkKeyPairFromPrivateKey(privateKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to generate JWK pair from private key: %v", err)
|
||||||
|
}
|
||||||
|
// save new key to key path to reuse later
|
||||||
|
b := cryptox.MarshalRSAPrivateKey(privateKey)
|
||||||
|
err = os.WriteFile(keyPath, b, os.ModePerm)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("failed to write private key to file: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
privateKey, err := cryptox.GenerateRSAPrivateKey(rawPrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to generate RSA key from string: %v", err)
|
||||||
|
}
|
||||||
|
privateJwk, publicJwk, err = cryptox.GenerateJwkKeyPairFromPrivateKey(privateKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to generate JWK pair from private key: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
publicJwk.Set("kid", uuid.New().String())
|
||||||
|
publicJwk.Set("use", "sig")
|
||||||
|
|
||||||
|
if err := publicJwk.Validate(); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to validate public JWK: %v", err)
|
||||||
|
}
|
||||||
|
trustedIssuer.PublicKey = publicJwk
|
||||||
|
|
||||||
|
// 3.b ...and then, add opaal's server host as a trusted issuer with JWK
|
||||||
|
if verbose {
|
||||||
|
fmt.Printf("Attempting to add issuer to authorization server...\n")
|
||||||
|
}
|
||||||
|
res, err := client.AddTrustedIssuer(
|
||||||
|
eps.TrustedIssuers,
|
||||||
|
trustedIssuer,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to add trusted issuer: %v", err)
|
||||||
|
}
|
||||||
|
fmt.Printf("trusted issuer: %v\n", string(res))
|
||||||
|
// TODO: add trusted issuer to cache if successful
|
||||||
|
|
||||||
|
// 4. create a new JWT based on the claims from the identity provider and sign
|
||||||
|
parsedIdToken, err := jwt.ParseString(idToken, jwt.WithKeySet(idp.KeySet))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to parse ID token: %v", err)
|
||||||
|
}
|
||||||
|
payload := parsedIdToken.PrivateClaims()
|
||||||
|
payload["iss"] = trustedIssuer.Issuer
|
||||||
|
payload["aud"] = []string{eps.Token}
|
||||||
|
payload["iat"] = time.Now().Unix()
|
||||||
|
payload["nbf"] = time.Now().Unix()
|
||||||
|
payload["exp"] = time.Now().Add(time.Second * 3600).Unix()
|
||||||
|
payload["sub"] = "opaal"
|
||||||
|
payloadJson, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to marshal payload: %v", err)
|
||||||
|
}
|
||||||
|
fmt.Printf("payload: %v\n", string(payloadJson))
|
||||||
|
newJwt, err := jws.Sign(payloadJson, jws.WithJSON(), jws.WithKey(jwa.RS256, privateJwk))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to sign token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. dynamically register new OAuth client and authorize it to make jwt_bearer request
|
||||||
|
fmt.Printf("Registering new OAuth2 client with authorization server...\n")
|
||||||
|
res, err = client.RegisterOAuthClient(eps.Register)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to register client: %v", err)
|
||||||
|
}
|
||||||
|
fmt.Printf("%v\n", string(res))
|
||||||
|
|
||||||
|
// extract the client info from response
|
||||||
|
var clientData map[string]any
|
||||||
|
err = json.Unmarshal(res, &clientData)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to unmarshal client data: %v", err)
|
||||||
|
} else {
|
||||||
|
// check for error first
|
||||||
|
errJson := clientData["error"]
|
||||||
|
if errJson == nil {
|
||||||
|
client.Id = clientData["client_id"].(string)
|
||||||
|
client.Secret = clientData["client_secret"].(string)
|
||||||
|
} else {
|
||||||
|
// delete client and try to create again
|
||||||
|
fmt.Printf("Attempting to delete client...\n")
|
||||||
|
err := client.DeleteOAuthClient(eps.Clients)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to delete OAuth client: %v", err)
|
||||||
|
}
|
||||||
|
fmt.Printf("Attempting to re-create client...\n")
|
||||||
|
res, err := client.CreateOAuthClient(eps.Clients)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to register client: %v", err)
|
||||||
|
}
|
||||||
|
fmt.Printf("%v\n", string(res))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// TODO: add OAuth client to cache if successfully
|
||||||
|
|
||||||
|
// authorize the client
|
||||||
|
// fmt.Printf("Attempting to authorize client...\n")
|
||||||
|
// res, err = client.AuthorizeOAuthClient(config.Authorization.RequestUrls.Authorize)
|
||||||
|
// if err != nil {
|
||||||
|
// return fmt.Errorf("failed to authorize client: %v", err)
|
||||||
|
// }
|
||||||
|
// fmt.Printf("%v\n", string(res))
|
||||||
|
|
||||||
|
// 6. send JWT to authorization server and receive a access token
|
||||||
|
if eps.Token != "" {
|
||||||
|
fmt.Printf("Fetching access token from authorization server...\n")
|
||||||
|
res, err := client.PerformTokenGrant(eps.Token, string(newJwt))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to fetch access token: %v", err)
|
||||||
|
}
|
||||||
|
// extract token from response if there are no errors
|
||||||
|
var data map[string]any
|
||||||
|
err = json.Unmarshal(res, &data)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
if data["error"] != nil {
|
||||||
|
return "", fmt.Errorf("the authorization server returned an error (%v): %v", data["error"], data["error_description"])
|
||||||
|
}
|
||||||
|
fmt.Printf("%s\n", res)
|
||||||
|
|
||||||
|
err = json.Unmarshal(res, &data)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to unmarshal access token: %v", err)
|
||||||
|
}
|
||||||
|
return data["access_token"].(string), nil
|
||||||
|
} else {
|
||||||
|
return "", fmt.Errorf("token endpoint not set")
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(res), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ForwardToken(eps JwtBearerEndpoints, params JwtBearerFlowParams) error {
|
||||||
|
var (
|
||||||
|
client = params.Client
|
||||||
|
idToken = params.IdToken
|
||||||
|
idp = params.IdentityProvider
|
||||||
|
verbose = params.Verbose
|
||||||
|
)
|
||||||
|
|
||||||
|
// fetch JWKS and add issuer to authentication server to submit ID token
|
||||||
|
if verbose {
|
||||||
|
fmt.Printf("Fetching JWKS from authentication server for verification...\n")
|
||||||
|
}
|
||||||
|
err := idp.FetchJwks()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to fetch JWK: %v", err)
|
||||||
|
} else {
|
||||||
|
if verbose {
|
||||||
|
fmt.Printf("Successfully retrieved JWK from authentication server.\n\n")
|
||||||
|
fmt.Printf("Attempting to add issuer to authorization server...\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
ti := &oauth.TrustedIssuer{
|
||||||
|
Issuer: idp.Issuer,
|
||||||
|
Subject: "1",
|
||||||
|
ExpiresAt: time.Now().Add(time.Second * 3600),
|
||||||
|
}
|
||||||
|
res, err := client.AddTrustedIssuer(
|
||||||
|
eps.TrustedIssuers,
|
||||||
|
ti,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add trusted issuer: %v", err)
|
||||||
|
}
|
||||||
|
if verbose {
|
||||||
|
fmt.Printf("%v\n", string(res))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// try and register a new client with authorization server
|
||||||
|
if verbose {
|
||||||
|
fmt.Printf("Registering new OAuth2 client with authorization server...\n")
|
||||||
|
}
|
||||||
|
res, err := client.RegisterOAuthClient(eps.Register)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to register client: %v", err)
|
||||||
|
}
|
||||||
|
if verbose {
|
||||||
|
fmt.Printf("%v\n", string(res))
|
||||||
|
}
|
||||||
|
|
||||||
|
// extract the client info from response
|
||||||
|
var clientData map[string]any
|
||||||
|
err = json.Unmarshal(res, &clientData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal client data: %v", err)
|
||||||
|
} else {
|
||||||
|
// check for error first
|
||||||
|
errJson := clientData["error"]
|
||||||
|
if errJson == nil {
|
||||||
|
client.Id = clientData["client_id"].(string)
|
||||||
|
client.Secret = clientData["client_secret"].(string)
|
||||||
|
} else {
|
||||||
|
// delete client and create again
|
||||||
|
fmt.Printf("Attempting to delete client...\n")
|
||||||
|
err := client.DeleteOAuthClient(eps.Clients)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete OAuth client: %v", err)
|
||||||
|
}
|
||||||
|
fmt.Printf("Attempting to re-create client...\n")
|
||||||
|
res, err := client.CreateOAuthClient(eps.Clients)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to register client: %v", err)
|
||||||
|
}
|
||||||
|
fmt.Printf("%v\n", string(res))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// authorize the client
|
||||||
|
// fmt.Printf("Attempting to authorize client...\n")
|
||||||
|
// res, err = client.AuthorizeOAuthClient(config.Authorization.RequestUrls.Authorize)
|
||||||
|
// if err != nil {
|
||||||
|
// return fmt.Errorf("failed to authorize client: %v", err)
|
||||||
|
// }
|
||||||
|
// fmt.Printf("%v\n", string(res))
|
||||||
|
|
||||||
|
// use ID token/user info to fetch access token from authentication server
|
||||||
|
if eps.Token != "" {
|
||||||
|
if verbose {
|
||||||
|
fmt.Printf("Fetching access token from authorization server...\n")
|
||||||
|
}
|
||||||
|
res, err := client.PerformTokenGrant(eps.Token, idToken)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to fetch access token: %v", err)
|
||||||
|
}
|
||||||
|
if verbose {
|
||||||
|
fmt.Printf("%s\n", res)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("token endpoint is not set")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
package opaal
|
package oauth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
|
@ -1,65 +1,41 @@
|
||||||
package opaal
|
package oauth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"davidallendj/opaal/internal/oidc"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/cookiejar"
|
||||||
"net/url"
|
"net/url"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/davidallendj/go-utils/httpx"
|
"github.com/davidallendj/go-utils/httpx"
|
||||||
"github.com/davidallendj/go-utils/util"
|
"github.com/davidallendj/go-utils/util"
|
||||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
"golang.org/x/net/publicsuffix"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (client *Client) AddTrustedIssuer(url string, issuer string, key jwk.Key, subject string, expires time.Duration) ([]byte, error) {
|
type Client struct {
|
||||||
// hydra endpoint: POST /admin/trust/grants/jwt-bearer/issuers
|
http.Client
|
||||||
quotedScopes := make([]string, len(client.Scope))
|
Id string `db:"id" yaml:"id"`
|
||||||
for i, s := range client.Scope {
|
Secret string `db:"secret" yaml:"secret"`
|
||||||
quotedScopes[i] = fmt.Sprintf("\"%s\"", s)
|
Name string `db:"name" yaml:"name"`
|
||||||
}
|
Description string `db:"description" yaml:"description"`
|
||||||
jwkstr, err := json.Marshal(key)
|
Issuer string `db:"issuer" yaml:"issuer"`
|
||||||
if err != nil {
|
RegistrationAccessToken string `db:"registration_access_token" yaml:"registration-access-token"`
|
||||||
return nil, fmt.Errorf("failed to marshal JWK: %v", err)
|
RedirectUris []string `db:"redirect_uris" yaml:"redirect-uris"`
|
||||||
}
|
Scope []string `db:"scope" yaml:"scope"`
|
||||||
// NOTE: Can also include "jwks_uri" instead
|
Audience []string `db:"audience" yaml:"audience"`
|
||||||
data := []byte(fmt.Sprintf("{"+
|
FlowId string
|
||||||
"\"allow_any_subject\": false,"+
|
CsrfToken string
|
||||||
"\"issuer\": \"%s\","+
|
|
||||||
"\"subject\": \"%s\","+
|
|
||||||
"\"expires_at\": \"%v\","+
|
|
||||||
"\"jwk\": %v,"+
|
|
||||||
"\"scope\": [ %s ]"+
|
|
||||||
"}", issuer, subject, time.Now().Add(expires).Format(time.RFC3339), string(jwkstr), strings.Join(quotedScopes, ",")))
|
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(data))
|
|
||||||
// req.Header.Add("X-CSRF-Token", client.CsrfToken.Value)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to make request: %v", err)
|
|
||||||
}
|
|
||||||
req.Header.Add("Content-Type", "application/json")
|
|
||||||
// req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", idToken))
|
|
||||||
res, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to do request: %v", err)
|
|
||||||
}
|
|
||||||
defer res.Body.Close()
|
|
||||||
|
|
||||||
return io.ReadAll(res.Body)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client *Client) AddTrustedIssuerWithIdentityProvider(url string, idp *oidc.IdentityProvider, subject string, expires time.Duration) ([]byte, error) {
|
func NewClient() *Client {
|
||||||
// hydra endpoint: POST /admin/trust/grants/jwt-bearer/issuers
|
return &Client{}
|
||||||
key, ok := idp.Jwks.Key(0)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("no keys found in key set")
|
|
||||||
}
|
}
|
||||||
return client.AddTrustedIssuer(url, idp.Issuer, key, subject, expires)
|
|
||||||
|
func (client *Client) ClearCookies() {
|
||||||
|
jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
|
||||||
|
client.Jar = jar
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client *Client) IsOAuthClientRegistered(clientUrl string) (bool, error) {
|
func (client *Client) IsOAuthClientRegistered(clientUrl string) (bool, error) {
|
||||||
|
|
@ -107,9 +83,9 @@ func (client *Client) GetOAuthClient(clientUrl string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client *Client) CreateOAuthClient(registerUrl string, audience []string) ([]byte, error) {
|
func (client *Client) CreateOAuthClient(registerUrl string) ([]byte, error) {
|
||||||
// hydra endpoint: POST /clients
|
// hydra endpoint: POST /clients
|
||||||
audience = util.QuoteArrayStrings(audience)
|
audience := util.QuoteArrayStrings(client.Audience)
|
||||||
body := httpx.Body(fmt.Sprintf(`{
|
body := httpx.Body(fmt.Sprintf(`{
|
||||||
"client_id": "%s",
|
"client_id": "%s",
|
||||||
"client_name": "%s",
|
"client_name": "%s",
|
||||||
|
|
@ -151,9 +127,12 @@ func (client *Client) CreateOAuthClient(registerUrl string, audience []string) (
|
||||||
return b, err
|
return b, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client *Client) RegisterOAuthClient(registerUrl string, audience []string) ([]byte, error) {
|
func (client *Client) RegisterOAuthClient(registerUrl string) ([]byte, error) {
|
||||||
// hydra endpoint: POST /oauth2/register
|
// hydra endpoint: POST /oauth2/register
|
||||||
audience = util.QuoteArrayStrings(audience)
|
if registerUrl == "" {
|
||||||
|
return nil, fmt.Errorf("no URL provided")
|
||||||
|
}
|
||||||
|
audience := util.QuoteArrayStrings(client.Audience)
|
||||||
body := httpx.Body(fmt.Sprintf(`{
|
body := httpx.Body(fmt.Sprintf(`{
|
||||||
"client_name": "opaal",
|
"client_name": "opaal",
|
||||||
"token_endpoint_auth_method": "client_secret_post",
|
"token_endpoint_auth_method": "client_secret_post",
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
package opaal
|
package oauth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
99
internal/oauth/trusted.go
Normal file
99
internal/oauth/trusted.go
Normal file
|
|
@ -0,0 +1,99 @@
|
||||||
|
package oauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/davidallendj/go-utils/httpx"
|
||||||
|
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TrustedIssuer struct {
|
||||||
|
Id string `db:"id" yaml:"id"`
|
||||||
|
AllowAnySubject bool `db:"allow_any_subject" yaml:"allow-any-subject"`
|
||||||
|
ExpiresAt time.Time `db:"expires_at" yaml:"expires-at"`
|
||||||
|
Issuer string `db:"issuer" yaml:"issuer"`
|
||||||
|
PublicKey jwk.Key `db:"public_key" yaml:"public-key"`
|
||||||
|
Scope []string `db:"scope" yaml:"scope"`
|
||||||
|
Subject string `db:"subject" yaml:"subject"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTrustedIssuer() *TrustedIssuer {
|
||||||
|
return &TrustedIssuer{
|
||||||
|
AllowAnySubject: false,
|
||||||
|
ExpiresAt: time.Now().Add(time.Hour),
|
||||||
|
Scope: []string{"openid"},
|
||||||
|
Subject: "1",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ti *TrustedIssuer) IsTrustedIssuerValid() bool {
|
||||||
|
err := ti.PublicKey.Validate()
|
||||||
|
return ti.Issuer != "" && err == nil && ti.Subject != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseString(b []byte) (*TrustedIssuer, error) {
|
||||||
|
// take data from JSON to populate fields
|
||||||
|
ti := &TrustedIssuer{}
|
||||||
|
data := map[string]any{}
|
||||||
|
json.Unmarshal(b, &data)
|
||||||
|
return ti, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client *Client) ListTrustedIssuers(url string) ([]TrustedIssuer, error) {
|
||||||
|
// hydra endpoint: GET /admin/trust/grants/jwt-bearer/issuers
|
||||||
|
_, b, err := httpx.MakeHttpRequest(url, http.MethodGet, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to make request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// unmarshal results into TrustedIssuers objects
|
||||||
|
trustedIssuers := []TrustedIssuer{}
|
||||||
|
err = json.Unmarshal(b, &trustedIssuers)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal JSON: %v", err)
|
||||||
|
}
|
||||||
|
return trustedIssuers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client *Client) AddTrustedIssuer(url string, ti *TrustedIssuer) ([]byte, error) {
|
||||||
|
// hydra endpoint: POST /admin/trust/grants/jwt-bearer/issuers
|
||||||
|
quotedScopes := make([]string, len(client.Scope))
|
||||||
|
for i, s := range client.Scope {
|
||||||
|
quotedScopes[i] = fmt.Sprintf("\"%s\"", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: Can also include "jwks_uri" instead of "jwk"
|
||||||
|
body := map[string]any{
|
||||||
|
"allow_any_subject": ti.AllowAnySubject,
|
||||||
|
"issuer": ti.Issuer,
|
||||||
|
"expires_at": ti.ExpiresAt,
|
||||||
|
"jwk": ti.PublicKey,
|
||||||
|
"scope": client.Scope,
|
||||||
|
}
|
||||||
|
if !ti.AllowAnySubject {
|
||||||
|
body["subject"] = ti.Subject
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(body)
|
||||||
|
fmt.Printf("request: %v\n", string(b))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal request body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", url, bytes.NewBuffer(b))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to make request: %v", err)
|
||||||
|
}
|
||||||
|
req.Header.Add("Content-Type", "application/json")
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to do request: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
return io.ReadAll(res.Body)
|
||||||
|
}
|
||||||
|
|
@ -1,121 +0,0 @@
|
||||||
package opaal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/go-chi/chi/middleware"
|
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
"github.com/nikolalohinski/gonja/v2"
|
|
||||||
"github.com/nikolalohinski/gonja/v2/exec"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Server struct {
|
|
||||||
*http.Server
|
|
||||||
Host string `yaml:"host"`
|
|
||||||
Port int `yaml:"port"`
|
|
||||||
Callback string `yaml:"callback"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewServerWithConfig(conf *Config) *Server {
|
|
||||||
host := conf.Server.Host
|
|
||||||
port := conf.Server.Port
|
|
||||||
server := &Server{
|
|
||||||
Server: &http.Server{
|
|
||||||
Addr: fmt.Sprintf("%s:%d", host, port),
|
|
||||||
},
|
|
||||||
Host: host,
|
|
||||||
Port: port,
|
|
||||||
}
|
|
||||||
return server
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) SetListenAddr(host string, port int) {
|
|
||||||
s.Addr = s.GetListenAddr()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) GetListenAddr() string {
|
|
||||||
return fmt.Sprintf("%s:%d", s.Host, s.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) WaitForAuthorizationCode(loginUrl string, callback string) (string, error) {
|
|
||||||
// check if callback is set
|
|
||||||
if callback == "" {
|
|
||||||
callback = "/oidc/callback"
|
|
||||||
}
|
|
||||||
|
|
||||||
var code string
|
|
||||||
r := chi.NewRouter()
|
|
||||||
r.Use(middleware.RedirectSlashes)
|
|
||||||
r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
|
||||||
})
|
|
||||||
r.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// show login page with notice to redirect
|
|
||||||
template, err := gonja.FromFile("pages/index.html")
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
data := exec.NewContext(map[string]interface{}{
|
|
||||||
"loginUrl": loginUrl,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err = template.Execute(w, data); err != nil { // Prints: Hello Bob!
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
r.HandleFunc(callback, func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// get the code from the OIDC provider
|
|
||||||
if r != nil {
|
|
||||||
code = r.URL.Query().Get("code")
|
|
||||||
fmt.Printf("Authorization code: %v\n", code)
|
|
||||||
}
|
|
||||||
http.Redirect(w, r, "/redirect", http.StatusSeeOther)
|
|
||||||
})
|
|
||||||
r.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
err := s.Close()
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("failed to close server: %v\n", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
s.Handler = r
|
|
||||||
|
|
||||||
return code, s.ListenAndServe()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) Serve(data chan []byte) error {
|
|
||||||
output, ok := <-data
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("failed to receive data")
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("Received data: %v\n", string(output))
|
|
||||||
// http.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
|
|
||||||
// })
|
|
||||||
r := chi.NewRouter()
|
|
||||||
r.HandleFunc("/success", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
fmt.Printf("Serving success page.")
|
|
||||||
successPage, err := os.ReadFile("pages/success.html")
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("failed to load success page: %v\n", err)
|
|
||||||
}
|
|
||||||
successPage = []byte(strings.ReplaceAll(string(successPage), "{{access_token}}", string(output)))
|
|
||||||
w.Write(successPage)
|
|
||||||
})
|
|
||||||
r.HandleFunc("/error", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
fmt.Printf("Serving error page.")
|
|
||||||
errorPage, err := os.ReadFile("pages/success.html")
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("failed to load success page: %v\n", err)
|
|
||||||
}
|
|
||||||
// errorPage = []byte(strings.ReplaceAll(string(errorPage), "{{access_token}}", output))
|
|
||||||
w.Write(errorPage)
|
|
||||||
})
|
|
||||||
|
|
||||||
s.Handler = r
|
|
||||||
return s.ListenAndServe()
|
|
||||||
}
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue