diff --git a/cmd/login.go b/cmd/login.go
index 09b0b58..855a0e8 100644
--- a/cmd/login.go
+++ b/cmd/login.go
@@ -2,12 +2,9 @@ package cmd
import (
opaal "davidallendj/opaal/internal"
- cache "davidallendj/opaal/internal/cache/sqlite"
"davidallendj/opaal/internal/oauth"
- "davidallendj/opaal/internal/oidc"
"fmt"
"os"
- "slices"
"github.com/spf13/cobra"
)
@@ -25,73 +22,68 @@ var loginCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
for {
// try and find client with valid identity provider config
- var provider *oidc.IdentityProvider
- if target != "" {
- // only try to use client with name give
- index := slices.IndexFunc(config.Authentication.Clients, func(c oauth.Client) bool {
- return target == c.Name
- })
- if index < 0 {
- fmt.Printf("could not find the target client listed by name")
- os.Exit(1)
- }
- client := config.Authentication.Clients[index]
- _, err := cache.GetIdentityProvider(config.Options.CachePath, client.Issuer)
- if err != nil {
+ // var provider *oidc.IdentityProvider
+ // if target != "" {
+ // // only try to use client with name give
+ // index := slices.IndexFunc(config.Authentication.Clients, func(c oauth.Client) bool {
+ // return target == c.Name
+ // })
+ // if index < 0 {
+ // fmt.Printf("could not find the target client listed by name")
+ // os.Exit(1)
+ // }
+ // client := config.Authentication.Clients[index]
+ // _, err := cache.GetIdentityProvider(config.Options.CachePath, client.Issuer)
+ // if err != nil {
- }
+ // }
- } else if targetIndex >= 0 {
- // only try to use client by index
- targetCount := len(config.Authentication.Clients) - 1
- if targetIndex > targetCount {
- fmt.Printf("target index out of range (found %d)", targetCount)
- }
- client := config.Authentication.Clients[targetIndex]
- _, err := cache.GetIdentityProvider(config.Options.CachePath, client.Issuer)
- if err != nil {
+ // } else if targetIndex >= 0 {
+ // // only try to use client by index
+ // targetCount := len(config.Authentication.Clients) - 1
+ // if targetIndex > targetCount {
+ // fmt.Printf("target index out of range (found %d)", targetCount)
+ // }
+ // client := config.Authentication.Clients[targetIndex]
+ // _, err := cache.GetIdentityProvider(config.Options.CachePath, client.Issuer)
+ // if err != nil {
- }
- } else {
- for _, c := range config.Authentication.Clients {
- // try to get identity provider info locally first
- _, err := cache.GetIdentityProvider(config.Options.CachePath, c.Issuer)
- if err != nil && !config.Options.CacheOnly {
- fmt.Printf("fetching config from issuer: %v\n", c.Issuer)
- // try to get info remotely by fetching
- provider, err = oidc.FetchServerConfig(c.Issuer)
- if err != nil {
- fmt.Printf("failed to fetch server config: %v\n", err)
- continue
- }
- client = c
- // fetch the provider's JWKS
- err := provider.FetchJwks()
- if err != nil {
- fmt.Printf("failed to fetch JWKS: %v\n", err)
- }
- break
- }
- // only test the first if --run-all flag is not set
- if !config.Authentication.TestAllClients {
- fmt.Printf("stopping after first test...\n\n\n")
- break
- }
- }
- }
+ // }
+ // } else {
+ // for _, c := range config.Authentication.Clients {
+ // // try to get identity provider info locally first
+ // _, err := cache.GetIdentityProvider(config.Options.CachePath, c.Issuer)
+ // if err != nil && !config.Options.CacheOnly {
+ // fmt.Printf("fetching config from issuer: %v\n", c.Issuer)
+ // // try to get info remotely by fetching
+ // provider, err = oidc.FetchServerConfig(c.Issuer)
+ // if err != nil {
+ // fmt.Printf("failed to fetch server config: %v\n", err)
+ // continue
+ // }
+ // client = c
+ // // fetch the provider's JWKS
+ // err := provider.FetchJwks()
+ // if err != nil {
+ // fmt.Printf("failed to fetch JWKS: %v\n", err)
+ // }
+ // break
+ // }
+ // // only test the first if --run-all flag is not set
+ // if !config.Authentication.TestAllClients {
+ // fmt.Printf("stopping after first test...\n\n\n")
+ // break
+ // }
+ // }
+ // }
- if provider == nil {
- fmt.Printf("failed to retrieve provider config\n")
- os.Exit(1)
- }
-
- // use clients to make SSO buttons that
- for _, client := range config.Authentication.Clients {
- MakeButton()
- }
+ // if provider == nil {
+ // fmt.Printf("failed to retrieve provider config\n")
+ // os.Exit(1)
+ // }
// start the listener
- err := opaal.Login(&config, &client, provider)
+ err := opaal.Login(&config)
if err != nil {
fmt.Printf("%v\n", err)
os.Exit(1)
diff --git a/internal/flows/jwt_bearer.go b/internal/flows/jwt_bearer.go
index 652944b..b73ebdc 100644
--- a/internal/flows/jwt_bearer.go
+++ b/internal/flows/jwt_bearer.go
@@ -4,7 +4,6 @@ import (
"crypto/rand"
"crypto/rsa"
"davidallendj/opaal/internal/oauth"
- "davidallendj/opaal/internal/oidc"
"encoding/json"
"fmt"
"os"
@@ -19,14 +18,14 @@ import (
)
type JwtBearerFlowParams struct {
- AccessToken string
- IdToken string
- IdentityProvider *oidc.IdentityProvider
- TrustedIssuer *oauth.TrustedIssuer
- Client *oauth.Client
- Refresh bool
- Verbose bool
- KeyPath string
+ AccessToken string
+ IdToken string
+ // IdentityProvider *oidc.IdentityProvider
+ TrustedIssuer *oauth.TrustedIssuer
+ Client *oauth.Client
+ Refresh bool
+ Verbose bool
+ KeyPath string
}
type JwtBearerFlowEndpoints struct {
@@ -39,22 +38,27 @@ type JwtBearerFlowEndpoints struct {
func NewJwtBearerFlow(eps JwtBearerFlowEndpoints, params JwtBearerFlowParams) (string, error) {
// 1. verify that the JWT from the issuer is valid using all keys
var (
- idp = params.IdentityProvider
+ // idp = params.IdentityProvider
accessToken = params.AccessToken
idToken = params.IdToken
client = params.Client
trustedIssuer = params.TrustedIssuer
verbose = params.Verbose
)
+
+ // pre-condition checks to make sure certain variables are set
+ if client == nil {
+ return "", fmt.Errorf("invalid client (client is nil)")
+ }
if accessToken != "" {
- _, err := jws.Verify([]byte(accessToken), jws.WithKeySet(idp.KeySet), jws.WithValidateKey(true))
+ _, err := jws.Verify([]byte(accessToken), jws.WithKeySet(client.Provider.KeySet), jws.WithValidateKey(true))
if err != nil {
return "", fmt.Errorf("failed to verify access token: %v", err)
}
}
if idToken != "" {
- _, err := jws.Verify([]byte(idToken), jws.WithKeySet(idp.KeySet), jws.WithValidateKey(true))
+ _, err := jws.Verify([]byte(idToken), jws.WithKeySet(client.Provider.KeySet), jws.WithValidateKey(true))
if err != nil {
return "", fmt.Errorf("failed to verify ID token: %v", err)
}
@@ -126,7 +130,7 @@ func NewJwtBearerFlow(eps JwtBearerFlowEndpoints, params JwtBearerFlowParams) (s
// TODO: add trusted issuer to cache if successful
// 4. create a new JWT based on the claims from the identity provider and sign
- parsedIdToken, err := jwt.ParseString(idToken, jwt.WithKeySet(idp.KeySet))
+ parsedIdToken, err := jwt.ParseString(idToken, jwt.WithKeySet(client.Provider.KeySet))
if err != nil {
return "", fmt.Errorf("failed to parse ID token: %v", err)
}
@@ -242,7 +246,7 @@ func ForwardToken(eps JwtBearerFlowEndpoints, params JwtBearerFlowParams) error
var (
client = params.Client
idToken = params.IdToken
- idp = params.IdentityProvider
+ // idp = params.IdentityProvider
verbose = params.Verbose
)
@@ -250,7 +254,7 @@ func ForwardToken(eps JwtBearerFlowEndpoints, params JwtBearerFlowParams) error
if verbose {
fmt.Printf("Fetching JWKS from authentication server for verification...\n")
}
- err := idp.FetchJwks()
+ err := client.Provider.FetchJwks()
if err != nil {
return fmt.Errorf("failed to fetch JWK: %v", err)
} else {
@@ -260,7 +264,7 @@ func ForwardToken(eps JwtBearerFlowEndpoints, params JwtBearerFlowParams) error
}
ti := &oauth.TrustedIssuer{
- Issuer: idp.Issuer,
+ Issuer: client.Provider.Issuer,
Subject: "1",
ExpiresAt: time.Now().Add(time.Second * 3600),
}
diff --git a/internal/login.go b/internal/login.go
index e9c6374..d542974 100644
--- a/internal/login.go
+++ b/internal/login.go
@@ -12,19 +12,11 @@ import (
"time"
)
-func Login(config *Config, client *oauth.Client, provider *oidc.IdentityProvider) error {
+func Login(config *Config) error {
if config == nil {
return fmt.Errorf("invalid config")
}
- if client == nil {
- return fmt.Errorf("invalid client")
- }
-
- if provider == nil {
- return fmt.Errorf("invalid identity provider")
- }
-
// make cache if it's not where expect
_, err := cache.CreateIdentityProvidersIfNotExists(config.Options.CachePath)
if err != nil {
@@ -39,18 +31,12 @@ func Login(config *Config, client *oauth.Client, provider *oidc.IdentityProvider
}
// print the authorization URL for sharing
- var authorizationUrl = client.BuildAuthorizationUrl(provider.Endpoints.Authorization, state)
s := NewServerWithConfig(config)
- fmt.Printf("Login with external identity provider:\n\n %s/login\n %s\n\n",
- s.GetListenAddr(), authorizationUrl,
- )
+ s.State = state
- var button = MakeButton(authorizationUrl, "Login with "+client.Name)
var authzClient = oauth.NewClient()
authzClient.Scope = config.Authorization.Token.Scope
- // authorize oauth client and listen for callback from provider
- fmt.Printf("Waiting for authorization code redirect @%s/oidc/callback...\n", s.GetListenAddr())
params := server.ServerParams{
Verbose: config.Options.Verbose,
AuthProvider: &oidc.IdentityProvider{
@@ -66,8 +52,7 @@ func Login(config *Config, client *oauth.Client, provider *oidc.IdentityProvider
Register: config.Authorization.Endpoints.Register,
},
JwtBearerParams: flows.JwtBearerFlowParams{
- Client: authzClient,
- IdentityProvider: provider,
+ Client: authzClient,
TrustedIssuer: &oauth.TrustedIssuer{
AllowAnySubject: false,
Issuer: s.Addr,
@@ -87,7 +72,7 @@ func Login(config *Config, client *oauth.Client, provider *oidc.IdentityProvider
Client: authzClient,
},
}
- err = s.StartLogin(button, provider, client, params)
+ err = s.StartLogin(config.Authentication.Clients, params)
if errors.Is(err, http.ErrServerClosed) {
fmt.Printf("\n=========================================\nServer closed.\n=========================================\n\n")
} else if err != nil {
@@ -96,7 +81,7 @@ func Login(config *Config, client *oauth.Client, provider *oidc.IdentityProvider
} else if config.Options.FlowType == "client_credentials" {
params := flows.ClientCredentialsFlowParams{
- Client: client,
+ Client: nil, // # FIXME: need to do something about this being nil I think
}
_, err := NewClientCredentialsFlowWithConfig(config, params)
if err != nil {
@@ -108,13 +93,3 @@ func Login(config *Config, client *oauth.Client, provider *oidc.IdentityProvider
return nil
}
-
-func MakeButton(url string, text string) string {
- // check if we have http:// a
- html := " " + text + ""
-}
diff --git a/internal/new.go b/internal/new.go
index 2b2bad8..55d9cda 100644
--- a/internal/new.go
+++ b/internal/new.go
@@ -29,7 +29,7 @@ func NewClientWithConfig(config *Config) *oauth.Client {
Id: clients[0].Id,
Secret: clients[0].Secret,
Name: clients[0].Name,
- Issuer: clients[0].Issuer,
+ Provider: clients[0].Provider,
Scope: clients[0].Scope,
RedirectUris: clients[0].RedirectUris,
}
@@ -53,7 +53,7 @@ func NewClientWithConfigByName(config *Config, name string) *oauth.Client {
func NewClientWithConfigByProvider(config *Config, issuer string) *oauth.Client {
index := slices.IndexFunc(config.Authentication.Clients, func(c oauth.Client) bool {
- return c.Issuer == issuer
+ return c.Provider.Issuer == issuer
})
if index >= 0 {
diff --git a/internal/oauth/authenticate.go b/internal/oauth/authenticate.go
index 5724526..b579e8e 100644
--- a/internal/oauth/authenticate.go
+++ b/internal/oauth/authenticate.go
@@ -16,12 +16,15 @@ func (client *Client) IsFlowInitiated() bool {
return client.FlowId != ""
}
-func (client *Client) BuildAuthorizationUrl(issuer string, state string) string {
- return issuer + "?" + "client_id=" + client.Id +
+func (client *Client) BuildAuthorizationUrl(state string) string {
+ url := client.Provider.Endpoints.Authorization + "?client_id=" + client.Id +
"&redirect_uri=" + url.QueryEscape(strings.Join(client.RedirectUris, ",")) +
"&response_type=code" + // this has to be set to "code"
- "&state=" + state +
"&scope=" + strings.Join(client.Scope, "+")
+ if state != "" {
+ url += "&state=" + state
+ }
+ return url
}
func (client *Client) InitiateLoginFlow(loginUrl string) error {
@@ -90,7 +93,7 @@ func (client *Client) FetchCSRFToken(flowUrl string) error {
return fmt.Errorf("failed to extract CSRF token: not found")
}
-func (client *Client) FetchTokenFromAuthenticationServer(code string, remoteUrl string, state string) ([]byte, error) {
+func (client *Client) FetchTokenFromAuthenticationServer(code string, state string) ([]byte, error) {
body := url.Values{
"grant_type": {"authorization_code"},
"client_id": {client.Id},
@@ -104,7 +107,7 @@ func (client *Client) FetchTokenFromAuthenticationServer(code string, remoteUrl
if state != "" {
body["state"] = []string{state}
}
- res, err := http.PostForm(remoteUrl, body)
+ res, err := http.PostForm(client.Provider.Endpoints.Token, body)
if err != nil {
return nil, fmt.Errorf("failed to get ID token: %s", err)
}
diff --git a/internal/oauth/client.go b/internal/oauth/client.go
index 2040117..78ab9cc 100644
--- a/internal/oauth/client.go
+++ b/internal/oauth/client.go
@@ -1,6 +1,7 @@
package oauth
import (
+ "davidallendj/opaal/internal/oidc"
"encoding/json"
"fmt"
"net/http"
@@ -24,15 +25,15 @@ const (
type Client struct {
http.Client
- Id string `db:"id" yaml:"id"`
- Secret string `db:"secret" yaml:"secret"`
- Name string `db:"name" yaml:"name"`
- Description string `db:"description" yaml:"description"`
- Issuer string `db:"issuer" yaml:"issuer"`
- RegistrationAccessToken string `db:"registration_access_token" yaml:"registration-access-token"`
- RedirectUris []string `db:"redirect_uris" yaml:"redirect-uris"`
- Scope []string `db:"scope" yaml:"scope"`
- Audience []string `db:"audience" yaml:"audience"`
+ Id string `db:"id" yaml:"id"`
+ Secret string `db:"secret" yaml:"secret"`
+ Name string `db:"name" yaml:"name"`
+ Description string `db:"description" yaml:"description"`
+ Provider oidc.IdentityProvider `db:"issuer" yaml:"provider"`
+ RegistrationAccessToken string `db:"registration_access_token" yaml:"registration-access-token"`
+ RedirectUris []string `db:"redirect_uris" yaml:"redirect-uris"`
+ Scope []string `db:"scope" yaml:"scope"`
+ Audience []string `db:"audience" yaml:"audience"`
FlowId string
CsrfToken string
}
diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go
index af813ab..2ec1f9c 100644
--- a/internal/oidc/oidc.go
+++ b/internal/oidc/oidc.go
@@ -111,26 +111,11 @@ func (p *IdentityProvider) LoadServerConfig(path string) error {
}
func (p *IdentityProvider) FetchServerConfig() error {
- // make a request to a server's openid-configuration
- req, err := http.NewRequest(http.MethodGet, p.Issuer+"/.well-known/openid-configuration", bytes.NewBuffer([]byte{}))
+ tmp, err := FetchServerConfig(p.Issuer)
if err != nil {
- return fmt.Errorf("failed to create a new request: %v", err)
- }
-
- client := &http.Client{} // temp client to get info and not used in flow
- res, err := client.Do(req)
- if err != nil {
- return fmt.Errorf("failed to do request: %v", err)
- }
-
- body, err := io.ReadAll(res.Body)
- if err != nil {
- return fmt.Errorf("failed to read response body: %v", err)
- }
- err = p.ParseServerConfig(body)
- if err != nil {
- return fmt.Errorf("failed to parse server config: %v", err)
+ return err
}
+ p = tmp
return nil
}
@@ -147,10 +132,15 @@ func FetchServerConfig(issuer string) (*IdentityProvider, error) {
return nil, fmt.Errorf("failed to do request: %v", err)
}
+ if res.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("HTTP status code: %d", res.StatusCode)
+ }
+
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %v", err)
}
+
var p IdentityProvider
err = p.ParseServerConfig(body)
if err != nil {
diff --git a/internal/server/server.go b/internal/server/server.go
index 84c6aa4..3409630 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -58,10 +58,12 @@ func (s *Server) GetListenAddr() string {
return fmt.Sprintf("%s:%d", s.Host, s.Port)
}
-func (s *Server) StartLogin(buttons string, provider *oidc.IdentityProvider, client *oauth.Client, params ServerParams) error {
+func (s *Server) StartLogin(clients []oauth.Client, params ServerParams) error {
var (
- target = ""
- callback = ""
+ target string
+ callback string
+ client *oauth.Client
+ sso string
)
// check if callback is set
@@ -69,6 +71,29 @@ func (s *Server) StartLogin(buttons string, provider *oidc.IdentityProvider, cli
callback = "/oidc/callback"
}
+ // make the login page SSO buttons and authorization URLs to write to stdout
+ buttons := ""
+ fmt.Printf("Login with external identity providers: \n")
+ for i, client := range clients {
+ // fetch provider configuration before adding button
+ p, err := oidc.FetchServerConfig(client.Provider.Issuer)
+ if err != nil {
+ fmt.Printf("failed to fetch server config: %v\n", err)
+ continue
+ }
+
+ // if we're able to get the config, go ahead and try to fetch jwks too
+ if err = p.FetchJwks(); err != nil {
+ fmt.Printf("failed to fetch JWKS: %v\n", err)
+ continue
+ }
+
+ clients[i].Provider = *p
+ buttons += makeButton(fmt.Sprintf("/login?sso=%s", client.Id), client.Name)
+ url := client.BuildAuthorizationUrl(s.State)
+ fmt.Printf("\t%s\n", url)
+ }
+
var code string
var accessToken string
r := chi.NewRouter()
@@ -93,21 +118,32 @@ func (s *Server) StartLogin(buttons string, provider *oidc.IdentityProvider, cli
// add target if query exists
if r != nil {
target = r.URL.Query().Get("target")
- sso := r.URL.Query().Get("sso")
+ sso = r.URL.Query().Get("sso")
+
+ // TODO: get client from list and build the authorization URL string
+ index := slices.IndexFunc(clients, func(c oauth.Client) bool {
+ return c.Id == sso
+ })
+
+ // TODO: redirect the user to authorization URL and return from func
+ foundClient := index >= 0
+ if foundClient {
+ client = &clients[index]
+
+ url := client.BuildAuthorizationUrl(s.State)
+ fmt.Printf("Redirect URL: %s\n", url)
+ http.Redirect(w, r, url, http.StatusFound)
+ return
+ }
}
+
// show login page with notice to redirect
template, err := gonja.FromFile("pages/index.html")
if err != nil {
panic(err)
}
- // form, err := os.ReadFile("pages/login.html")
- // if err != nil {
- // fmt.Printf("failed to load login form: %v", err)
- // }
-
data := exec.NewContext(map[string]interface{}{
- // "loginForm": string(form),
"loginButtons": buttons,
})
@@ -158,7 +194,7 @@ func (s *Server) StartLogin(buttons string, provider *oidc.IdentityProvider, cli
// use refresh token provided to do a refresh token grant
refreshToken := r.URL.Query().Get("refresh-token")
if refreshToken != "" {
- _, err := params.JwtBearerParams.Client.PerformRefreshTokenGrant(provider.Endpoints.Token, refreshToken)
+ _, err := params.JwtBearerParams.Client.PerformRefreshTokenGrant(client.Provider.Endpoints.Token, refreshToken)
if err != nil {
fmt.Printf("failed to perform refresh token grant: %v\n", err)
http.Redirect(w, r, "/error", http.StatusInternalServerError)
@@ -196,10 +232,15 @@ func (s *Server) StartLogin(buttons string, provider *oidc.IdentityProvider, cli
fmt.Printf("Authorization code: %v\n", code)
}
+ // make sure we have the correct client to use
+ if client == nil {
+ fmt.Printf("failed to find valid client")
+ return
+ }
+
// use code from response and exchange for bearer token (with ID token)
bearerToken, err := client.FetchTokenFromAuthenticationServer(
code,
- provider.Endpoints.Token,
s.State,
)
if err != nil {
@@ -229,6 +270,7 @@ func (s *Server) StartLogin(buttons string, provider *oidc.IdentityProvider, cli
// complete JWT bearer flow to receive access token from authorization server
// fmt.Printf("bearer: %v\n", string(bearerToken))
params.JwtBearerParams.IdToken = data["id_token"].(string)
+ params.JwtBearerParams.Client = client
accessToken, err = flows.NewJwtBearerFlow(params.JwtBearerEndpoints, params.JwtBearerParams)
if err != nil {
fmt.Printf("failed to complete JWT bearer flow: %v\n", err)
@@ -407,10 +449,12 @@ func (s *Server) StartIdentityProvider() error {
// example username and password so do simplified authorization code flow
if username == "ochami" && password == "ochami" {
client := oauth.Client{
- Id: "ochami",
- Secret: "ochami",
- Name: "ochami",
- Issuer: "http://127.0.0.1:3333",
+ Id: "ochami",
+ Secret: "ochami",
+ Name: "ochami",
+ Provider: oidc.IdentityProvider{
+ Issuer: "http://127.0.0.1:3333",
+ },
RedirectUris: []string{fmt.Sprintf("http://%s:%d%s", s.Host, s.Port, callback)},
}
@@ -542,3 +586,13 @@ func (s *Server) StartIdentityProvider() error {
s.Handler = r
return s.ListenAndServe()
}
+
+func makeButton(url string, text string) string {
+ // check if we have http:// a
+ html := "", text)
+ return html
+ // return " " + text + ""
+}