mirror of
https://github.com/davidallendj/opaal.git
synced 2025-12-20 03:27:02 -07:00
Major refactoring and code restructure
This commit is contained in:
parent
72adbe1f0d
commit
6d63211d35
10 changed files with 454 additions and 859 deletions
|
|
@ -1,424 +0,0 @@
|
|||
package opaal
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"davidallendj/opaal/internal/oidc"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/davidallendj/go-utils/util"
|
||||
"github.com/google/uuid"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jws"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
)
|
||||
|
||||
// TODO: change authorization code flow to use these instead
|
||||
type AuthorizationCodeFlowEndpoints struct {
|
||||
Login string
|
||||
Token string
|
||||
Identities string
|
||||
TrustedIssuer string
|
||||
Register string
|
||||
}
|
||||
|
||||
func AuthorizationCodeWithConfig(config *Config, server *Server, client *Client, idp *oidc.IdentityProvider) error {
|
||||
// check preconditions are met
|
||||
err := verifyParams(config, server, client, idp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// build the authorization URL to redirect user for social sign-in
|
||||
state := config.Authentication.Flows["authorization_code"]["state"]
|
||||
var authorizationUrl = client.BuildAuthorizationUrl(idp.Endpoints.Authorization, state)
|
||||
|
||||
// print the authorization URL for sharing
|
||||
fmt.Printf("Login with identity provider:\n\n %s/login\n %s\n\n",
|
||||
server.GetListenAddr(), authorizationUrl,
|
||||
)
|
||||
|
||||
// automatically open browser to initiate login flow (only useful for testing and debugging)
|
||||
if config.Options.OpenBrowser {
|
||||
util.OpenUrl(authorizationUrl)
|
||||
}
|
||||
|
||||
// authorize oauth client and listen for callback from provider
|
||||
fmt.Printf("Waiting for authorization code redirect @%s/oidc/callback...\n", server.GetListenAddr())
|
||||
code, err := server.WaitForAuthorizationCode(authorizationUrl, "")
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
fmt.Printf("\n=========================================\nServer closed.\n=========================================\n\n")
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
// start up another server in background to listen for success or failures
|
||||
d := StartListener(server)
|
||||
|
||||
// use code from response and exchange for bearer token (with ID token)
|
||||
bearerToken, err := client.FetchTokenFromAuthenticationServer(
|
||||
code,
|
||||
idp.Endpoints.Token,
|
||||
state,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch token from issuer: %v", err)
|
||||
}
|
||||
// fmt.Printf("%v\n", string(bearerToken))
|
||||
|
||||
// unmarshal data to get id_token and access_token
|
||||
var data map[string]any
|
||||
err = json.Unmarshal([]byte(bearerToken), &data)
|
||||
if err != nil || data == nil {
|
||||
return fmt.Errorf("failed to unmarshal token: %v", err)
|
||||
}
|
||||
|
||||
// make sure we have an ID token
|
||||
if data["id_token"] == nil {
|
||||
return fmt.Errorf("no ID token found...aborting")
|
||||
}
|
||||
|
||||
// extract ID token from bearer as JSON string for easy consumption
|
||||
idToken := data["id_token"].(string)
|
||||
idJwtSegments, err := util.DecodeJwt(idToken)
|
||||
if err != nil {
|
||||
fmt.Printf("failed to parse ID token: %v\n", err)
|
||||
} else {
|
||||
fmt.Printf("id_token: %v\n", idToken)
|
||||
if config.Options.DecodeIdToken {
|
||||
if err != nil {
|
||||
fmt.Printf("failed to decode JWT: %v\n", err)
|
||||
} else {
|
||||
for i, segment := range idJwtSegments {
|
||||
// don't print last segment (signatures)
|
||||
if i == len(idJwtSegments)-1 {
|
||||
break
|
||||
}
|
||||
fmt.Printf("%s\n", string(segment))
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
// extract the access token to get the scopes
|
||||
accessToken := data["access_token"].(string)
|
||||
accessJwtSegments, err := util.DecodeJwt(accessToken)
|
||||
if err != nil || len(accessJwtSegments) <= 0 {
|
||||
fmt.Printf("failed to parse access token: %v\n", err)
|
||||
} else {
|
||||
fmt.Printf("access_token (from identity provider): %v\n", accessToken)
|
||||
if config.Options.DecodeIdToken {
|
||||
if err != nil {
|
||||
fmt.Printf("failed to decode JWT: %v\n", err)
|
||||
} else {
|
||||
for i, segment := range accessJwtSegments {
|
||||
// don't print last segment (signatures)
|
||||
if i == len(accessJwtSegments)-1 {
|
||||
break
|
||||
}
|
||||
fmt.Printf("%s\n", string(segment))
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
if !config.Options.ForwardToken {
|
||||
|
||||
// TODO: implement our own JWT to send to Hydra
|
||||
// 1. verify that the JWT from the issuer is valid
|
||||
key, ok := idp.Jwks.Key(0)
|
||||
if !ok {
|
||||
return fmt.Errorf("no key found in key set")
|
||||
}
|
||||
|
||||
parsedIdToken, err := jwt.ParseString(idToken, jwt.WithKey(jwa.RS256, key))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse ID token: %v", err)
|
||||
}
|
||||
_, err = jwt.ParseString(accessToken, jwt.WithKeySet(idp.Jwks))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse access token: %v", err)
|
||||
}
|
||||
|
||||
_, err = jws.Verify([]byte(idToken), jws.WithKeySet(idp.Jwks), jws.WithValidateKey(true))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to verify JWT: %v", err)
|
||||
}
|
||||
|
||||
// 2. create a new JWKS (or just JWK) to be verified
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate private RSA k-ey: %v", err)
|
||||
}
|
||||
privateJwk, err := jwk.FromRaw(privateKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create private JWK: %v", err)
|
||||
}
|
||||
publicJwk, err := jwk.PublicKeyOf(privateJwk)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create public JWK: %v", err)
|
||||
}
|
||||
publicJwk.Set("kid", uuid.New().String())
|
||||
|
||||
// 3. add opaal's server host as a trusted issuer with JWK
|
||||
fmt.Printf("Attempting to add issuer to authorization server...\n")
|
||||
res, err := client.AddTrustedIssuer(
|
||||
config.Authorization.RequestUrls.TrustedIssuers,
|
||||
server.Addr,
|
||||
publicJwk,
|
||||
"1",
|
||||
time.Second*3600,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add trusted issuer: %v", err)
|
||||
}
|
||||
fmt.Printf("%v\n", string(res))
|
||||
|
||||
// 4. create a new JWT based on the claims from the identity provider and sign
|
||||
payload := parsedIdToken.PrivateClaims()
|
||||
payload["iss"] = server.Addr
|
||||
payload["aud"] = []string{config.Authorization.RequestUrls.Token}
|
||||
payload["iat"] = time.Now().Unix()
|
||||
payload["nbf"] = time.Now().Unix()
|
||||
payload["exp"] = time.Now().Add(time.Second * 3600).Unix()
|
||||
payload["sub"] = "1"
|
||||
payloadJson, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %v", err)
|
||||
}
|
||||
newToken, err := jws.Sign(payloadJson, jws.WithJSON(), jws.WithKey(jwa.RS256, privateJwk))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sign token: %v", err)
|
||||
}
|
||||
|
||||
// sig = rsasha256(b64urlencode(header) + "." + b64urlencode(payload))
|
||||
// signature := util.EncodeBase64() + util.EncodeBase64() +
|
||||
|
||||
// 5. dynamically register new OAuth client and authorize it
|
||||
fmt.Printf("Registering new OAuth2 client with authorization server...\n")
|
||||
res, err = client.RegisterOAuthClient(config.Authorization.RequestUrls.Register, []string{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register client: %v", err)
|
||||
}
|
||||
fmt.Printf("%v\n", string(res))
|
||||
|
||||
// extract the client info from response
|
||||
var clientData map[string]any
|
||||
err = json.Unmarshal(res, &clientData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unmarshal client data: %v", err)
|
||||
} else {
|
||||
// check for error first
|
||||
errJson := clientData["error"]
|
||||
if errJson == nil {
|
||||
client.Id = clientData["client_id"].(string)
|
||||
client.Secret = clientData["client_secret"].(string)
|
||||
} else {
|
||||
// delete client and create again
|
||||
fmt.Printf("Attempting to delete client...\n")
|
||||
err := client.DeleteOAuthClient(config.Authorization.RequestUrls.Clients)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete OAuth client: %v", err)
|
||||
}
|
||||
fmt.Printf("Attempting to re-create client...\n")
|
||||
res, err := client.CreateOAuthClient(config.Authorization.RequestUrls.Clients, []string{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register client: %v", err)
|
||||
}
|
||||
fmt.Printf("%v\n", string(res))
|
||||
}
|
||||
}
|
||||
|
||||
// authorize the client
|
||||
// fmt.Printf("Attempting to authorize client...\n")
|
||||
// res, err = client.AuthorizeOAuthClient(config.Authorization.RequestUrls.Authorize)
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("failed to authorize client: %v", err)
|
||||
// }
|
||||
// fmt.Printf("%v\n", string(res))
|
||||
|
||||
// 6. send JWT to authorization server and receive a access token
|
||||
if config.Authorization.RequestUrls.Token != "" {
|
||||
fmt.Printf("Fetching access token from authorization server...\n")
|
||||
res, err := client.PerformTokenGrant(config.Authorization.RequestUrls.Token, string(newToken))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch access token: %v", err)
|
||||
}
|
||||
fmt.Printf("%s\n", res)
|
||||
}
|
||||
} else {
|
||||
// extract the scope from access token claims
|
||||
// var scope []string
|
||||
// var accessJsonPayload map[string]any
|
||||
// var accessJwtPayload []byte = accessJwtSegments[1]
|
||||
// if accessJsonPayload != nil {
|
||||
// err := json.Unmarshal(accessJwtPayload, &accessJsonPayload)
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("failed to unmarshal JWT: %v", err)
|
||||
// }
|
||||
// scope = idJsonPayload["scope"].([]string)
|
||||
// }
|
||||
|
||||
// create a new identity with identity and session manager if url is provided
|
||||
// if config.RequestUrls.Identities != "" {
|
||||
// fmt.Printf("Attempting to create a new identity...\n")
|
||||
// err := client.CreateIdentity(config.RequestUrls.Identities, idToken)
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("failed to create new identity: %v", err)
|
||||
// }
|
||||
// _, err = client.FetchIdentities(config.RequestUrls.Identities)
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("failed to fetch identities: %v", err)
|
||||
// }
|
||||
// fmt.Printf("Created new identity successfully.\n\n")
|
||||
// }
|
||||
|
||||
// extract the subject from ID token claims
|
||||
var subject string
|
||||
var audience []string
|
||||
var idJsonPayload map[string]any
|
||||
var idJwtPayload []byte = idJwtSegments[1]
|
||||
if idJwtPayload != nil {
|
||||
err := json.Unmarshal(idJwtPayload, &idJsonPayload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unmarshal JWT: %v", err)
|
||||
}
|
||||
subject = idJsonPayload["sub"].(string)
|
||||
audType := reflect.ValueOf(idJsonPayload["aud"])
|
||||
switch audType.Kind() {
|
||||
case reflect.String:
|
||||
audience = append(audience, idJsonPayload["aud"].(string))
|
||||
case reflect.Array:
|
||||
audience = idJsonPayload["aud"].([]string)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("failed to extract subject from ID token claims")
|
||||
}
|
||||
|
||||
// fetch JWKS and add issuer to authentication server to submit ID token
|
||||
fmt.Printf("Fetching JWKS from authentication server for verification...\n")
|
||||
err = idp.FetchJwks()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch JWK: %v", err)
|
||||
} else {
|
||||
fmt.Printf("Successfully retrieved JWK from authentication server.\n\n")
|
||||
fmt.Printf("Attempting to add issuer to authorization server...\n")
|
||||
res, err := client.AddTrustedIssuerWithIdentityProvider(
|
||||
config.Authorization.RequestUrls.TrustedIssuers,
|
||||
idp,
|
||||
subject,
|
||||
time.Duration(1000),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add trusted issuer: %v", err)
|
||||
}
|
||||
fmt.Printf("%v\n", string(res))
|
||||
}
|
||||
|
||||
// add client ID to audience
|
||||
audience = append(audience, client.Id)
|
||||
audience = append(audience, "http://127.0.0.1:4444/oauth2/token")
|
||||
|
||||
// try and register a new client with authorization server
|
||||
fmt.Printf("Registering new OAuth2 client with authorization server...\n")
|
||||
res, err := client.RegisterOAuthClient(config.Authorization.RequestUrls.Register, audience)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register client: %v", err)
|
||||
}
|
||||
fmt.Printf("%v\n", string(res))
|
||||
|
||||
// extract the client info from response
|
||||
var clientData map[string]any
|
||||
err = json.Unmarshal(res, &clientData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unmarshal client data: %v", err)
|
||||
} else {
|
||||
// check for error first
|
||||
errJson := clientData["error"]
|
||||
if errJson == nil {
|
||||
client.Id = clientData["client_id"].(string)
|
||||
client.Secret = clientData["client_secret"].(string)
|
||||
} else {
|
||||
// delete client and create again
|
||||
fmt.Printf("Attempting to delete client...\n")
|
||||
err := client.DeleteOAuthClient(config.Authorization.RequestUrls.Clients)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete OAuth client: %v", err)
|
||||
}
|
||||
fmt.Printf("Attempting to re-create client...\n")
|
||||
res, err := client.CreateOAuthClient(config.Authorization.RequestUrls.Clients, audience)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register client: %v", err)
|
||||
}
|
||||
fmt.Printf("%v\n", string(res))
|
||||
}
|
||||
}
|
||||
|
||||
// authorize the client
|
||||
// fmt.Printf("Attempting to authorize client...\n")
|
||||
// res, err = client.AuthorizeOAuthClient(config.Authorization.RequestUrls.Authorize)
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("failed to authorize client: %v", err)
|
||||
// }
|
||||
// fmt.Printf("%v\n", string(res))
|
||||
|
||||
// use ID token/user info to fetch access token from authentication server
|
||||
if config.Authorization.RequestUrls.Token != "" {
|
||||
fmt.Printf("Fetching access token from authorization server...\n")
|
||||
res, err := client.PerformTokenGrant(config.Authorization.RequestUrls.Token, idToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch access token: %v", err)
|
||||
}
|
||||
fmt.Printf("%s\n", res)
|
||||
}
|
||||
}
|
||||
var access_token []byte
|
||||
d <- access_token
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyParams(config *Config, server *Server, client *Client, idp *oidc.IdentityProvider) error {
|
||||
// make sure we have a valid server and client
|
||||
if server == nil {
|
||||
return fmt.Errorf("server not initialized or valid (server == nil)")
|
||||
}
|
||||
if client == nil {
|
||||
return fmt.Errorf("client not initialized or valid (client == nil)")
|
||||
}
|
||||
if idp == nil {
|
||||
return fmt.Errorf("identity provider not initialized or valid (idp == nil)")
|
||||
}
|
||||
// check if all appropriate parameters are set in config
|
||||
if !HasRequiredConfigParams(config) {
|
||||
return fmt.Errorf("required params not set correctly or missing")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func StartListener(server *Server) chan []byte {
|
||||
d := make(chan []byte)
|
||||
quit := make(chan bool)
|
||||
|
||||
go server.Serve(d)
|
||||
go func() {
|
||||
select {
|
||||
case <-d:
|
||||
fmt.Printf("got access token")
|
||||
quit <- true
|
||||
case <-quit:
|
||||
close(d)
|
||||
close(quit)
|
||||
return
|
||||
default:
|
||||
}
|
||||
}()
|
||||
return d
|
||||
}
|
||||
|
|
@ -1,93 +0,0 @@
|
|||
package opaal
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"slices"
|
||||
|
||||
"github.com/davidallendj/go-utils/mathx"
|
||||
"golang.org/x/net/publicsuffix"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
http.Client
|
||||
Id string `yaml:"id"`
|
||||
Secret string `yaml:"secret"`
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
Issuer string `yaml:"issuer"`
|
||||
RegistrationAccessToken string `yaml:"registration-access-token"`
|
||||
RedirectUris []string `yaml:"redirect-uris"`
|
||||
Scope []string `yaml:"scope"`
|
||||
FlowId string
|
||||
CsrfToken string
|
||||
}
|
||||
|
||||
func NewClient() *Client {
|
||||
return &Client{}
|
||||
}
|
||||
|
||||
func NewClientWithConfig(config *Config) *Client {
|
||||
// make sure config is valid
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// make sure we have at least one client
|
||||
clients := config.Authentication.Clients
|
||||
if len(clients) <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// use the first client found by default
|
||||
return &Client{
|
||||
Id: clients[0].Id,
|
||||
Secret: clients[0].Secret,
|
||||
Name: clients[0].Name,
|
||||
Issuer: clients[0].Issuer,
|
||||
Scope: clients[0].Scope,
|
||||
RedirectUris: clients[0].RedirectUris,
|
||||
}
|
||||
}
|
||||
|
||||
func NewClientWithConfigByIndex(config *Config, index int) *Client {
|
||||
size := len(config.Authentication.Clients)
|
||||
index = mathx.Clamp(index, 0, size)
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewClientWithConfigByName(config *Config, name string) *Client {
|
||||
index := slices.IndexFunc(config.Authentication.Clients, func(c Client) bool {
|
||||
return c.Name == name
|
||||
})
|
||||
if index >= 0 {
|
||||
return &config.Authentication.Clients[index]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewClientWithConfigByProvider(config *Config, issuer string) *Client {
|
||||
index := slices.IndexFunc(config.Authentication.Clients, func(c Client) bool {
|
||||
return c.Issuer == issuer
|
||||
})
|
||||
|
||||
if index >= 0 {
|
||||
return &config.Authentication.Clients[index]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewClientWithConfigById(config *Config, id string) *Client {
|
||||
index := slices.IndexFunc(config.Authentication.Clients, func(c Client) bool {
|
||||
return c.Id == id
|
||||
})
|
||||
if index >= 0 {
|
||||
return &config.Authentication.Clients[index]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client *Client) ClearCookies() {
|
||||
jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
|
||||
client.Jar = jar
|
||||
}
|
||||
|
|
@ -1,156 +0,0 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"davidallendj/opaal/internal/oidc"
|
||||
"fmt"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func CreateIdentityProvidersIfNotExists(path string) (*sqlx.DB, error) {
|
||||
schema := `
|
||||
CREATE TABLE IF NOT EXISTS identity_providers (
|
||||
issuer TEXT NOT NULL,
|
||||
authorization_endpoint TEXT,
|
||||
token_endpoint TEXT,
|
||||
revocation_endpoint TEXT,
|
||||
introspection_endpoint TEXT,
|
||||
userinfo_endpoint TEXT,
|
||||
jwks_uri TEXT,
|
||||
response_types_supported TEXT,
|
||||
response_modes_supported TEXT,
|
||||
grant_types_supported TEXT,
|
||||
token_endpoint_auth_methods_supported TEXT,
|
||||
subject_types_supported TEXT,
|
||||
id_token_signing_alg_values_supported TEXT,
|
||||
claim_types_supported TEXT,
|
||||
claims_supported TEXT,
|
||||
jwks TEXT,
|
||||
|
||||
PRIMARY KEY (issuer)
|
||||
);
|
||||
`
|
||||
db, err := sqlx.Open("sqlite3", path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not open database: %v", err)
|
||||
}
|
||||
db.MustExec(schema)
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func InsertIdentityProviders(path string, providers *[]oidc.IdentityProvider) error {
|
||||
if providers == nil {
|
||||
return fmt.Errorf("states == nil")
|
||||
}
|
||||
|
||||
// create database if it doesn't already exist
|
||||
db, err := CreateIdentityProvidersIfNotExists(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// insert all probe states into db
|
||||
tx := db.MustBegin()
|
||||
for _, state := range *providers {
|
||||
sql := `INSERT OR REPLACE INTO identity_providers
|
||||
(
|
||||
issuer,
|
||||
authorization_endpoint,
|
||||
token_endpoint,
|
||||
revocation_endpoint,
|
||||
introspection_endpoint,
|
||||
userinfo_endpoint,
|
||||
jwks_uri,
|
||||
response_types_supported,
|
||||
response_modes_supported,
|
||||
grant_types_supported,
|
||||
token_endpoint_auth_methods_supported,
|
||||
subject_types_supported,
|
||||
id_token_signing_alg_values_supported,
|
||||
claim_types_supported,
|
||||
claims_supported,
|
||||
jwks
|
||||
)
|
||||
VALUES
|
||||
(
|
||||
:issuer,
|
||||
:authorization_endpoint,
|
||||
:token_endpoint,
|
||||
:revocation_endpoint,
|
||||
:introspection_endpoint,
|
||||
:userinfo_endpoint,
|
||||
:jwks_uri,
|
||||
:response_types_supported,
|
||||
:response_modes_supported,
|
||||
:grant_types_supported,
|
||||
:token_endpoint_auth_methods_supported,
|
||||
:subject_types_supported,
|
||||
:id_token_signing_alg_values_supported,
|
||||
:claim_types_supported,
|
||||
:claims_supported,
|
||||
:jwks
|
||||
);`
|
||||
_, err := tx.NamedExec(sql, &state)
|
||||
if err != nil {
|
||||
fmt.Printf("could not execute transaction: %v\n", err)
|
||||
}
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not commit transaction: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetIdentityProvider(path string, issuer string) (*oidc.IdentityProvider, error) {
|
||||
db, err := sqlx.Open("sqlite3", path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not open database: %v", err)
|
||||
}
|
||||
|
||||
results := &oidc.IdentityProvider{}
|
||||
err = db.Select(&results, "SELECT * FROM magellan_scanned_ports ORDER BY host ASC, port ASC LIMIT 1;")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not retrieve probes: %v", err)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func GetIdentityProviders(path string) ([]oidc.IdentityProvider, error) {
|
||||
db, err := sqlx.Open("sqlite3", path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not open database: %v", err)
|
||||
}
|
||||
|
||||
results := []oidc.IdentityProvider{}
|
||||
err = db.Select(&results, "SELECT * FROM magellan_scanned_ports ORDER BY host ASC, port ASC;")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not retrieve probes: %v", err)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func DeleteIdentityProviders(path string, results *[]oidc.IdentityProvider) error {
|
||||
if results == nil {
|
||||
return fmt.Errorf("no probe results found")
|
||||
}
|
||||
db, err := sqlx.Open("sqlite3", path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not open database: %v", err)
|
||||
}
|
||||
tx := db.MustBegin()
|
||||
for _, state := range *results {
|
||||
sql := `DELETE FROM identity_providers WHERE host = :issuer;`
|
||||
_, err := tx.NamedExec(sql, &state)
|
||||
if err != nil {
|
||||
fmt.Printf("could not execute transaction: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not commit transaction: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
package opaal
|
||||
package flows
|
||||
|
||||
import (
|
||||
"davidallendj/opaal/internal/oauth"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
|
|
@ -15,9 +16,9 @@ type ClientCredentialsFlowEndpoints struct {
|
|||
Token string
|
||||
}
|
||||
|
||||
func ClientCredentials(eps ClientCredentialsFlowEndpoints, client *Client) error {
|
||||
func NewClientCredentialsFlow(eps ClientCredentialsFlowEndpoints, client *oauth.Client) error {
|
||||
// register a new OAuth 2 client with authorization srever
|
||||
_, err := client.CreateOAuthClient(eps.Create, nil)
|
||||
_, err := client.CreateOAuthClient(eps.Create)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register OAuth client: %v", err)
|
||||
}
|
||||
|
|
@ -37,12 +38,3 @@ func ClientCredentials(eps ClientCredentialsFlowEndpoints, client *Client) error
|
|||
fmt.Printf("token: %v\n", string(res))
|
||||
return nil
|
||||
}
|
||||
|
||||
func ClientCredentialsWithConfig(config *Config, client *Client) error {
|
||||
eps := ClientCredentialsFlowEndpoints{
|
||||
Create: config.Authorization.RequestUrls.Clients,
|
||||
Authorize: config.Authorization.RequestUrls.Authorize,
|
||||
Token: config.Authorization.RequestUrls.Token,
|
||||
}
|
||||
return ClientCredentials(eps, client)
|
||||
}
|
||||
319
internal/flows/jwt_bearer.go
Normal file
319
internal/flows/jwt_bearer.go
Normal file
|
|
@ -0,0 +1,319 @@
|
|||
package flows
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"davidallendj/opaal/internal/oauth"
|
||||
"davidallendj/opaal/internal/oidc"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/davidallendj/go-utils/cryptox"
|
||||
"github.com/google/uuid"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jws"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
)
|
||||
|
||||
type JwtBearerFlowParams struct {
|
||||
AccessToken string
|
||||
IdToken string
|
||||
IdentityProvider *oidc.IdentityProvider
|
||||
TrustedIssuer *oauth.TrustedIssuer
|
||||
Client *oauth.Client
|
||||
Verbose bool
|
||||
KeyPath string
|
||||
}
|
||||
|
||||
type JwtBearerEndpoints struct {
|
||||
TrustedIssuers string
|
||||
Token string
|
||||
Clients string
|
||||
Register string
|
||||
}
|
||||
|
||||
func NewJwtBearerFlow(eps JwtBearerEndpoints, params JwtBearerFlowParams) (string, error) {
|
||||
// 1. verify that the JWT from the issuer is valid using all keys
|
||||
var (
|
||||
idp = params.IdentityProvider
|
||||
accessToken = params.AccessToken
|
||||
idToken = params.IdToken
|
||||
client = params.Client
|
||||
trustedIssuer = params.TrustedIssuer
|
||||
verbose = params.Verbose
|
||||
)
|
||||
if accessToken != "" {
|
||||
_, err := jws.Verify([]byte(accessToken), jws.WithKeySet(idp.KeySet), jws.WithValidateKey(true))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to verify access token: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if idToken != "" {
|
||||
_, err := jws.Verify([]byte(idToken), jws.WithKeySet(idp.KeySet), jws.WithValidateKey(true))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to verify ID token: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Check if we are already registered as a trusted issuer with authorization server...
|
||||
|
||||
// 3.a if not, create a new JWKS (or just JWK) to be verified
|
||||
var (
|
||||
keyPath string = params.KeyPath
|
||||
privateJwk jwk.Key
|
||||
publicJwk jwk.Key
|
||||
)
|
||||
rawPrivateKey, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
if verbose {
|
||||
fmt.Printf("failed to read private key...generating a new one.\n")
|
||||
}
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate new RSA key: %v", err)
|
||||
}
|
||||
privateJwk, publicJwk, err = cryptox.GenerateJwkKeyPairFromPrivateKey(privateKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate JWK pair from private key: %v", err)
|
||||
}
|
||||
// save new key to key path to reuse later
|
||||
b := cryptox.MarshalRSAPrivateKey(privateKey)
|
||||
err = os.WriteFile(keyPath, b, os.ModePerm)
|
||||
if err != nil {
|
||||
fmt.Printf("failed to write private key to file: %v", err)
|
||||
}
|
||||
} else {
|
||||
privateKey, err := cryptox.GenerateRSAPrivateKey(rawPrivateKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate RSA key from string: %v", err)
|
||||
}
|
||||
privateJwk, publicJwk, err = cryptox.GenerateJwkKeyPairFromPrivateKey(privateKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate JWK pair from private key: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
publicJwk.Set("kid", uuid.New().String())
|
||||
publicJwk.Set("use", "sig")
|
||||
|
||||
if err := publicJwk.Validate(); err != nil {
|
||||
return "", fmt.Errorf("failed to validate public JWK: %v", err)
|
||||
}
|
||||
trustedIssuer.PublicKey = publicJwk
|
||||
|
||||
// 3.b ...and then, add opaal's server host as a trusted issuer with JWK
|
||||
if verbose {
|
||||
fmt.Printf("Attempting to add issuer to authorization server...\n")
|
||||
}
|
||||
res, err := client.AddTrustedIssuer(
|
||||
eps.TrustedIssuers,
|
||||
trustedIssuer,
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to add trusted issuer: %v", err)
|
||||
}
|
||||
fmt.Printf("trusted issuer: %v\n", string(res))
|
||||
// TODO: add trusted issuer to cache if successful
|
||||
|
||||
// 4. create a new JWT based on the claims from the identity provider and sign
|
||||
parsedIdToken, err := jwt.ParseString(idToken, jwt.WithKeySet(idp.KeySet))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse ID token: %v", err)
|
||||
}
|
||||
payload := parsedIdToken.PrivateClaims()
|
||||
payload["iss"] = trustedIssuer.Issuer
|
||||
payload["aud"] = []string{eps.Token}
|
||||
payload["iat"] = time.Now().Unix()
|
||||
payload["nbf"] = time.Now().Unix()
|
||||
payload["exp"] = time.Now().Add(time.Second * 3600).Unix()
|
||||
payload["sub"] = "opaal"
|
||||
payloadJson, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal payload: %v", err)
|
||||
}
|
||||
fmt.Printf("payload: %v\n", string(payloadJson))
|
||||
newJwt, err := jws.Sign(payloadJson, jws.WithJSON(), jws.WithKey(jwa.RS256, privateJwk))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign token: %v", err)
|
||||
}
|
||||
|
||||
// 5. dynamically register new OAuth client and authorize it to make jwt_bearer request
|
||||
fmt.Printf("Registering new OAuth2 client with authorization server...\n")
|
||||
res, err = client.RegisterOAuthClient(eps.Register)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to register client: %v", err)
|
||||
}
|
||||
fmt.Printf("%v\n", string(res))
|
||||
|
||||
// extract the client info from response
|
||||
var clientData map[string]any
|
||||
err = json.Unmarshal(res, &clientData)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to unmarshal client data: %v", err)
|
||||
} else {
|
||||
// check for error first
|
||||
errJson := clientData["error"]
|
||||
if errJson == nil {
|
||||
client.Id = clientData["client_id"].(string)
|
||||
client.Secret = clientData["client_secret"].(string)
|
||||
} else {
|
||||
// delete client and try to create again
|
||||
fmt.Printf("Attempting to delete client...\n")
|
||||
err := client.DeleteOAuthClient(eps.Clients)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to delete OAuth client: %v", err)
|
||||
}
|
||||
fmt.Printf("Attempting to re-create client...\n")
|
||||
res, err := client.CreateOAuthClient(eps.Clients)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to register client: %v", err)
|
||||
}
|
||||
fmt.Printf("%v\n", string(res))
|
||||
}
|
||||
}
|
||||
// TODO: add OAuth client to cache if successfully
|
||||
|
||||
// authorize the client
|
||||
// fmt.Printf("Attempting to authorize client...\n")
|
||||
// res, err = client.AuthorizeOAuthClient(config.Authorization.RequestUrls.Authorize)
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("failed to authorize client: %v", err)
|
||||
// }
|
||||
// fmt.Printf("%v\n", string(res))
|
||||
|
||||
// 6. send JWT to authorization server and receive a access token
|
||||
if eps.Token != "" {
|
||||
fmt.Printf("Fetching access token from authorization server...\n")
|
||||
res, err := client.PerformTokenGrant(eps.Token, string(newJwt))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to fetch access token: %v", err)
|
||||
}
|
||||
// extract token from response if there are no errors
|
||||
var data map[string]any
|
||||
err = json.Unmarshal(res, &data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
if data["error"] != nil {
|
||||
return "", fmt.Errorf("the authorization server returned an error (%v): %v", data["error"], data["error_description"])
|
||||
}
|
||||
fmt.Printf("%s\n", res)
|
||||
|
||||
err = json.Unmarshal(res, &data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to unmarshal access token: %v", err)
|
||||
}
|
||||
return data["access_token"].(string), nil
|
||||
} else {
|
||||
return "", fmt.Errorf("token endpoint not set")
|
||||
}
|
||||
|
||||
return string(res), nil
|
||||
}
|
||||
|
||||
func ForwardToken(eps JwtBearerEndpoints, params JwtBearerFlowParams) error {
|
||||
var (
|
||||
client = params.Client
|
||||
idToken = params.IdToken
|
||||
idp = params.IdentityProvider
|
||||
verbose = params.Verbose
|
||||
)
|
||||
|
||||
// fetch JWKS and add issuer to authentication server to submit ID token
|
||||
if verbose {
|
||||
fmt.Printf("Fetching JWKS from authentication server for verification...\n")
|
||||
}
|
||||
err := idp.FetchJwks()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch JWK: %v", err)
|
||||
} else {
|
||||
if verbose {
|
||||
fmt.Printf("Successfully retrieved JWK from authentication server.\n\n")
|
||||
fmt.Printf("Attempting to add issuer to authorization server...\n")
|
||||
}
|
||||
|
||||
ti := &oauth.TrustedIssuer{
|
||||
Issuer: idp.Issuer,
|
||||
Subject: "1",
|
||||
ExpiresAt: time.Now().Add(time.Second * 3600),
|
||||
}
|
||||
res, err := client.AddTrustedIssuer(
|
||||
eps.TrustedIssuers,
|
||||
ti,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add trusted issuer: %v", err)
|
||||
}
|
||||
if verbose {
|
||||
fmt.Printf("%v\n", string(res))
|
||||
}
|
||||
}
|
||||
|
||||
// try and register a new client with authorization server
|
||||
if verbose {
|
||||
fmt.Printf("Registering new OAuth2 client with authorization server...\n")
|
||||
}
|
||||
res, err := client.RegisterOAuthClient(eps.Register)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register client: %v", err)
|
||||
}
|
||||
if verbose {
|
||||
fmt.Printf("%v\n", string(res))
|
||||
}
|
||||
|
||||
// extract the client info from response
|
||||
var clientData map[string]any
|
||||
err = json.Unmarshal(res, &clientData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unmarshal client data: %v", err)
|
||||
} else {
|
||||
// check for error first
|
||||
errJson := clientData["error"]
|
||||
if errJson == nil {
|
||||
client.Id = clientData["client_id"].(string)
|
||||
client.Secret = clientData["client_secret"].(string)
|
||||
} else {
|
||||
// delete client and create again
|
||||
fmt.Printf("Attempting to delete client...\n")
|
||||
err := client.DeleteOAuthClient(eps.Clients)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete OAuth client: %v", err)
|
||||
}
|
||||
fmt.Printf("Attempting to re-create client...\n")
|
||||
res, err := client.CreateOAuthClient(eps.Clients)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register client: %v", err)
|
||||
}
|
||||
fmt.Printf("%v\n", string(res))
|
||||
}
|
||||
}
|
||||
|
||||
// authorize the client
|
||||
// fmt.Printf("Attempting to authorize client...\n")
|
||||
// res, err = client.AuthorizeOAuthClient(config.Authorization.RequestUrls.Authorize)
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("failed to authorize client: %v", err)
|
||||
// }
|
||||
// fmt.Printf("%v\n", string(res))
|
||||
|
||||
// use ID token/user info to fetch access token from authentication server
|
||||
if eps.Token != "" {
|
||||
if verbose {
|
||||
fmt.Printf("Fetching access token from authorization server...\n")
|
||||
}
|
||||
res, err := client.PerformTokenGrant(eps.Token, idToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch access token: %v", err)
|
||||
}
|
||||
if verbose {
|
||||
fmt.Printf("%s\n", res)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("token endpoint is not set")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
package opaal
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
|
@ -1,65 +1,41 @@
|
|||
package opaal
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"davidallendj/opaal/internal/oidc"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/davidallendj/go-utils/httpx"
|
||||
"github.com/davidallendj/go-utils/util"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"golang.org/x/net/publicsuffix"
|
||||
)
|
||||
|
||||
func (client *Client) AddTrustedIssuer(url string, issuer string, key jwk.Key, subject string, expires time.Duration) ([]byte, error) {
|
||||
// hydra endpoint: POST /admin/trust/grants/jwt-bearer/issuers
|
||||
quotedScopes := make([]string, len(client.Scope))
|
||||
for i, s := range client.Scope {
|
||||
quotedScopes[i] = fmt.Sprintf("\"%s\"", s)
|
||||
}
|
||||
jwkstr, err := json.Marshal(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal JWK: %v", err)
|
||||
}
|
||||
// NOTE: Can also include "jwks_uri" instead
|
||||
data := []byte(fmt.Sprintf("{"+
|
||||
"\"allow_any_subject\": false,"+
|
||||
"\"issuer\": \"%s\","+
|
||||
"\"subject\": \"%s\","+
|
||||
"\"expires_at\": \"%v\","+
|
||||
"\"jwk\": %v,"+
|
||||
"\"scope\": [ %s ]"+
|
||||
"}", issuer, subject, time.Now().Add(expires).Format(time.RFC3339), string(jwkstr), strings.Join(quotedScopes, ",")))
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(data))
|
||||
// req.Header.Add("X-CSRF-Token", client.CsrfToken.Value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make request: %v", err)
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
// req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", idToken))
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to do request: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
return io.ReadAll(res.Body)
|
||||
type Client struct {
|
||||
http.Client
|
||||
Id string `db:"id" yaml:"id"`
|
||||
Secret string `db:"secret" yaml:"secret"`
|
||||
Name string `db:"name" yaml:"name"`
|
||||
Description string `db:"description" yaml:"description"`
|
||||
Issuer string `db:"issuer" yaml:"issuer"`
|
||||
RegistrationAccessToken string `db:"registration_access_token" yaml:"registration-access-token"`
|
||||
RedirectUris []string `db:"redirect_uris" yaml:"redirect-uris"`
|
||||
Scope []string `db:"scope" yaml:"scope"`
|
||||
Audience []string `db:"audience" yaml:"audience"`
|
||||
FlowId string
|
||||
CsrfToken string
|
||||
}
|
||||
|
||||
func (client *Client) AddTrustedIssuerWithIdentityProvider(url string, idp *oidc.IdentityProvider, subject string, expires time.Duration) ([]byte, error) {
|
||||
// hydra endpoint: POST /admin/trust/grants/jwt-bearer/issuers
|
||||
key, ok := idp.Jwks.Key(0)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no keys found in key set")
|
||||
func NewClient() *Client {
|
||||
return &Client{}
|
||||
}
|
||||
return client.AddTrustedIssuer(url, idp.Issuer, key, subject, expires)
|
||||
|
||||
func (client *Client) ClearCookies() {
|
||||
jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
|
||||
client.Jar = jar
|
||||
}
|
||||
|
||||
func (client *Client) IsOAuthClientRegistered(clientUrl string) (bool, error) {
|
||||
|
|
@ -107,9 +83,9 @@ func (client *Client) GetOAuthClient(clientUrl string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (client *Client) CreateOAuthClient(registerUrl string, audience []string) ([]byte, error) {
|
||||
func (client *Client) CreateOAuthClient(registerUrl string) ([]byte, error) {
|
||||
// hydra endpoint: POST /clients
|
||||
audience = util.QuoteArrayStrings(audience)
|
||||
audience := util.QuoteArrayStrings(client.Audience)
|
||||
body := httpx.Body(fmt.Sprintf(`{
|
||||
"client_id": "%s",
|
||||
"client_name": "%s",
|
||||
|
|
@ -151,9 +127,12 @@ func (client *Client) CreateOAuthClient(registerUrl string, audience []string) (
|
|||
return b, err
|
||||
}
|
||||
|
||||
func (client *Client) RegisterOAuthClient(registerUrl string, audience []string) ([]byte, error) {
|
||||
func (client *Client) RegisterOAuthClient(registerUrl string) ([]byte, error) {
|
||||
// hydra endpoint: POST /oauth2/register
|
||||
audience = util.QuoteArrayStrings(audience)
|
||||
if registerUrl == "" {
|
||||
return nil, fmt.Errorf("no URL provided")
|
||||
}
|
||||
audience := util.QuoteArrayStrings(client.Audience)
|
||||
body := httpx.Body(fmt.Sprintf(`{
|
||||
"client_name": "opaal",
|
||||
"token_endpoint_auth_method": "client_secret_post",
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
package opaal
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
99
internal/oauth/trusted.go
Normal file
99
internal/oauth/trusted.go
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
package oauth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/davidallendj/go-utils/httpx"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
)
|
||||
|
||||
type TrustedIssuer struct {
|
||||
Id string `db:"id" yaml:"id"`
|
||||
AllowAnySubject bool `db:"allow_any_subject" yaml:"allow-any-subject"`
|
||||
ExpiresAt time.Time `db:"expires_at" yaml:"expires-at"`
|
||||
Issuer string `db:"issuer" yaml:"issuer"`
|
||||
PublicKey jwk.Key `db:"public_key" yaml:"public-key"`
|
||||
Scope []string `db:"scope" yaml:"scope"`
|
||||
Subject string `db:"subject" yaml:"subject"`
|
||||
}
|
||||
|
||||
func NewTrustedIssuer() *TrustedIssuer {
|
||||
return &TrustedIssuer{
|
||||
AllowAnySubject: false,
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
Scope: []string{"openid"},
|
||||
Subject: "1",
|
||||
}
|
||||
}
|
||||
|
||||
func (ti *TrustedIssuer) IsTrustedIssuerValid() bool {
|
||||
err := ti.PublicKey.Validate()
|
||||
return ti.Issuer != "" && err == nil && ti.Subject != ""
|
||||
}
|
||||
|
||||
func ParseString(b []byte) (*TrustedIssuer, error) {
|
||||
// take data from JSON to populate fields
|
||||
ti := &TrustedIssuer{}
|
||||
data := map[string]any{}
|
||||
json.Unmarshal(b, &data)
|
||||
return ti, nil
|
||||
}
|
||||
|
||||
func (client *Client) ListTrustedIssuers(url string) ([]TrustedIssuer, error) {
|
||||
// hydra endpoint: GET /admin/trust/grants/jwt-bearer/issuers
|
||||
_, b, err := httpx.MakeHttpRequest(url, http.MethodGet, nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make request: %v", err)
|
||||
}
|
||||
|
||||
// unmarshal results into TrustedIssuers objects
|
||||
trustedIssuers := []TrustedIssuer{}
|
||||
err = json.Unmarshal(b, &trustedIssuers)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal JSON: %v", err)
|
||||
}
|
||||
return trustedIssuers, nil
|
||||
}
|
||||
|
||||
func (client *Client) AddTrustedIssuer(url string, ti *TrustedIssuer) ([]byte, error) {
|
||||
// hydra endpoint: POST /admin/trust/grants/jwt-bearer/issuers
|
||||
quotedScopes := make([]string, len(client.Scope))
|
||||
for i, s := range client.Scope {
|
||||
quotedScopes[i] = fmt.Sprintf("\"%s\"", s)
|
||||
}
|
||||
|
||||
// NOTE: Can also include "jwks_uri" instead of "jwk"
|
||||
body := map[string]any{
|
||||
"allow_any_subject": ti.AllowAnySubject,
|
||||
"issuer": ti.Issuer,
|
||||
"expires_at": ti.ExpiresAt,
|
||||
"jwk": ti.PublicKey,
|
||||
"scope": client.Scope,
|
||||
}
|
||||
if !ti.AllowAnySubject {
|
||||
body["subject"] = ti.Subject
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
fmt.Printf("request: %v\n", string(b))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(b))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make request: %v", err)
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to do request: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
return io.ReadAll(res.Body)
|
||||
}
|
||||
|
|
@ -1,121 +0,0 @@
|
|||
package opaal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/middleware"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/nikolalohinski/gonja/v2"
|
||||
"github.com/nikolalohinski/gonja/v2/exec"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
*http.Server
|
||||
Host string `yaml:"host"`
|
||||
Port int `yaml:"port"`
|
||||
Callback string `yaml:"callback"`
|
||||
}
|
||||
|
||||
func NewServerWithConfig(conf *Config) *Server {
|
||||
host := conf.Server.Host
|
||||
port := conf.Server.Port
|
||||
server := &Server{
|
||||
Server: &http.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", host, port),
|
||||
},
|
||||
Host: host,
|
||||
Port: port,
|
||||
}
|
||||
return server
|
||||
}
|
||||
|
||||
func (s *Server) SetListenAddr(host string, port int) {
|
||||
s.Addr = s.GetListenAddr()
|
||||
}
|
||||
|
||||
func (s *Server) GetListenAddr() string {
|
||||
return fmt.Sprintf("%s:%d", s.Host, s.Port)
|
||||
}
|
||||
|
||||
func (s *Server) WaitForAuthorizationCode(loginUrl string, callback string) (string, error) {
|
||||
// check if callback is set
|
||||
if callback == "" {
|
||||
callback = "/oidc/callback"
|
||||
}
|
||||
|
||||
var code string
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.RedirectSlashes)
|
||||
r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||
})
|
||||
r.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
// show login page with notice to redirect
|
||||
template, err := gonja.FromFile("pages/index.html")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
data := exec.NewContext(map[string]interface{}{
|
||||
"loginUrl": loginUrl,
|
||||
})
|
||||
|
||||
if err = template.Execute(w, data); err != nil { // Prints: Hello Bob!
|
||||
panic(err)
|
||||
}
|
||||
})
|
||||
r.HandleFunc(callback, func(w http.ResponseWriter, r *http.Request) {
|
||||
// get the code from the OIDC provider
|
||||
if r != nil {
|
||||
code = r.URL.Query().Get("code")
|
||||
fmt.Printf("Authorization code: %v\n", code)
|
||||
}
|
||||
http.Redirect(w, r, "/redirect", http.StatusSeeOther)
|
||||
})
|
||||
r.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) {
|
||||
err := s.Close()
|
||||
if err != nil {
|
||||
fmt.Printf("failed to close server: %v\n", err)
|
||||
}
|
||||
})
|
||||
s.Handler = r
|
||||
|
||||
return code, s.ListenAndServe()
|
||||
}
|
||||
|
||||
func (s *Server) Serve(data chan []byte) error {
|
||||
output, ok := <-data
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to receive data")
|
||||
}
|
||||
|
||||
fmt.Printf("Received data: %v\n", string(output))
|
||||
// http.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// })
|
||||
r := chi.NewRouter()
|
||||
r.HandleFunc("/success", func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Printf("Serving success page.")
|
||||
successPage, err := os.ReadFile("pages/success.html")
|
||||
if err != nil {
|
||||
fmt.Printf("failed to load success page: %v\n", err)
|
||||
}
|
||||
successPage = []byte(strings.ReplaceAll(string(successPage), "{{access_token}}", string(output)))
|
||||
w.Write(successPage)
|
||||
})
|
||||
r.HandleFunc("/error", func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Printf("Serving error page.")
|
||||
errorPage, err := os.ReadFile("pages/success.html")
|
||||
if err != nil {
|
||||
fmt.Printf("failed to load success page: %v\n", err)
|
||||
}
|
||||
// errorPage = []byte(strings.ReplaceAll(string(errorPage), "{{access_token}}", output))
|
||||
w.Write(errorPage)
|
||||
})
|
||||
|
||||
s.Handler = r
|
||||
return s.ListenAndServe()
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue