From 61a35c165dbb7ede545c0419a9a09babd405133a Mon Sep 17 00:00:00 2001 From: "David J. Allen" Date: Thu, 18 Apr 2024 16:02:43 -0600 Subject: [PATCH 1/2] WIP refactoring login --- cmd/login.go | 15 +++++++++++++++ internal/server/server.go | 1 + 2 files changed, 16 insertions(+) diff --git a/cmd/login.go b/cmd/login.go index abe3452..09b0b58 100644 --- a/cmd/login.go +++ b/cmd/login.go @@ -85,6 +85,11 @@ var loginCmd = &cobra.Command{ os.Exit(1) } + // use clients to make SSO buttons that + for _, client := range config.Authentication.Clients { + MakeButton() + } + // start the listener err := opaal.Login(&config, &client, provider) if err != nil { @@ -115,3 +120,13 @@ func init() { loginCmd.MarkFlagsMutuallyExclusive("target.name", "target.index") rootCmd.AddCommand(loginCmd) } + +func MakeButton(url string, text string) string { + // check if we have http:// a + html := " " + text + "" +} diff --git a/internal/server/server.go b/internal/server/server.go index 66c45d4..84c6aa4 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -93,6 +93,7 @@ func (s *Server) StartLogin(buttons string, provider *oidc.IdentityProvider, cli // add target if query exists if r != nil { target = r.URL.Query().Get("target") + sso := r.URL.Query().Get("sso") } // show login page with notice to redirect template, err := gonja.FromFile("pages/index.html") From 6d2f488a6b570b6f5cc130b259725fa8908b2890 Mon Sep 17 00:00:00 2001 From: "David J. Allen" Date: Tue, 23 Apr 2024 13:17:41 -0600 Subject: [PATCH 2/2] Refactored login page and process --- cmd/login.go | 120 +++++++++++++++------------------ internal/flows/jwt_bearer.go | 36 +++++----- internal/login.go | 35 ++-------- internal/new.go | 4 +- internal/oauth/authenticate.go | 13 ++-- internal/oauth/client.go | 19 +++--- internal/oidc/oidc.go | 26 +++---- internal/server/server.go | 86 ++++++++++++++++++----- 8 files changed, 179 insertions(+), 160 deletions(-) diff --git a/cmd/login.go b/cmd/login.go index 09b0b58..855a0e8 100644 --- a/cmd/login.go +++ b/cmd/login.go @@ -2,12 +2,9 @@ package cmd import ( opaal "davidallendj/opaal/internal" - cache "davidallendj/opaal/internal/cache/sqlite" "davidallendj/opaal/internal/oauth" - "davidallendj/opaal/internal/oidc" "fmt" "os" - "slices" "github.com/spf13/cobra" ) @@ -25,73 +22,68 @@ var loginCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { for { // try and find client with valid identity provider config - var provider *oidc.IdentityProvider - if target != "" { - // only try to use client with name give - index := slices.IndexFunc(config.Authentication.Clients, func(c oauth.Client) bool { - return target == c.Name - }) - if index < 0 { - fmt.Printf("could not find the target client listed by name") - os.Exit(1) - } - client := config.Authentication.Clients[index] - _, err := cache.GetIdentityProvider(config.Options.CachePath, client.Issuer) - if err != nil { + // var provider *oidc.IdentityProvider + // if target != "" { + // // only try to use client with name give + // index := slices.IndexFunc(config.Authentication.Clients, func(c oauth.Client) bool { + // return target == c.Name + // }) + // if index < 0 { + // fmt.Printf("could not find the target client listed by name") + // os.Exit(1) + // } + // client := config.Authentication.Clients[index] + // _, err := cache.GetIdentityProvider(config.Options.CachePath, client.Issuer) + // if err != nil { - } + // } - } else if targetIndex >= 0 { - // only try to use client by index - targetCount := len(config.Authentication.Clients) - 1 - if targetIndex > targetCount { - fmt.Printf("target index out of range (found %d)", targetCount) - } - client := config.Authentication.Clients[targetIndex] - _, err := cache.GetIdentityProvider(config.Options.CachePath, client.Issuer) - if err != nil { + // } else if targetIndex >= 0 { + // // only try to use client by index + // targetCount := len(config.Authentication.Clients) - 1 + // if targetIndex > targetCount { + // fmt.Printf("target index out of range (found %d)", targetCount) + // } + // client := config.Authentication.Clients[targetIndex] + // _, err := cache.GetIdentityProvider(config.Options.CachePath, client.Issuer) + // if err != nil { - } - } else { - for _, c := range config.Authentication.Clients { - // try to get identity provider info locally first - _, err := cache.GetIdentityProvider(config.Options.CachePath, c.Issuer) - if err != nil && !config.Options.CacheOnly { - fmt.Printf("fetching config from issuer: %v\n", c.Issuer) - // try to get info remotely by fetching - provider, err = oidc.FetchServerConfig(c.Issuer) - if err != nil { - fmt.Printf("failed to fetch server config: %v\n", err) - continue - } - client = c - // fetch the provider's JWKS - err := provider.FetchJwks() - if err != nil { - fmt.Printf("failed to fetch JWKS: %v\n", err) - } - break - } - // only test the first if --run-all flag is not set - if !config.Authentication.TestAllClients { - fmt.Printf("stopping after first test...\n\n\n") - break - } - } - } + // } + // } else { + // for _, c := range config.Authentication.Clients { + // // try to get identity provider info locally first + // _, err := cache.GetIdentityProvider(config.Options.CachePath, c.Issuer) + // if err != nil && !config.Options.CacheOnly { + // fmt.Printf("fetching config from issuer: %v\n", c.Issuer) + // // try to get info remotely by fetching + // provider, err = oidc.FetchServerConfig(c.Issuer) + // if err != nil { + // fmt.Printf("failed to fetch server config: %v\n", err) + // continue + // } + // client = c + // // fetch the provider's JWKS + // err := provider.FetchJwks() + // if err != nil { + // fmt.Printf("failed to fetch JWKS: %v\n", err) + // } + // break + // } + // // only test the first if --run-all flag is not set + // if !config.Authentication.TestAllClients { + // fmt.Printf("stopping after first test...\n\n\n") + // break + // } + // } + // } - if provider == nil { - fmt.Printf("failed to retrieve provider config\n") - os.Exit(1) - } - - // use clients to make SSO buttons that - for _, client := range config.Authentication.Clients { - MakeButton() - } + // if provider == nil { + // fmt.Printf("failed to retrieve provider config\n") + // os.Exit(1) + // } // start the listener - err := opaal.Login(&config, &client, provider) + err := opaal.Login(&config) if err != nil { fmt.Printf("%v\n", err) os.Exit(1) diff --git a/internal/flows/jwt_bearer.go b/internal/flows/jwt_bearer.go index 652944b..b73ebdc 100644 --- a/internal/flows/jwt_bearer.go +++ b/internal/flows/jwt_bearer.go @@ -4,7 +4,6 @@ import ( "crypto/rand" "crypto/rsa" "davidallendj/opaal/internal/oauth" - "davidallendj/opaal/internal/oidc" "encoding/json" "fmt" "os" @@ -19,14 +18,14 @@ import ( ) type JwtBearerFlowParams struct { - AccessToken string - IdToken string - IdentityProvider *oidc.IdentityProvider - TrustedIssuer *oauth.TrustedIssuer - Client *oauth.Client - Refresh bool - Verbose bool - KeyPath string + AccessToken string + IdToken string + // IdentityProvider *oidc.IdentityProvider + TrustedIssuer *oauth.TrustedIssuer + Client *oauth.Client + Refresh bool + Verbose bool + KeyPath string } type JwtBearerFlowEndpoints struct { @@ -39,22 +38,27 @@ type JwtBearerFlowEndpoints struct { func NewJwtBearerFlow(eps JwtBearerFlowEndpoints, params JwtBearerFlowParams) (string, error) { // 1. verify that the JWT from the issuer is valid using all keys var ( - idp = params.IdentityProvider + // idp = params.IdentityProvider accessToken = params.AccessToken idToken = params.IdToken client = params.Client trustedIssuer = params.TrustedIssuer verbose = params.Verbose ) + + // pre-condition checks to make sure certain variables are set + if client == nil { + return "", fmt.Errorf("invalid client (client is nil)") + } if accessToken != "" { - _, err := jws.Verify([]byte(accessToken), jws.WithKeySet(idp.KeySet), jws.WithValidateKey(true)) + _, err := jws.Verify([]byte(accessToken), jws.WithKeySet(client.Provider.KeySet), jws.WithValidateKey(true)) if err != nil { return "", fmt.Errorf("failed to verify access token: %v", err) } } if idToken != "" { - _, err := jws.Verify([]byte(idToken), jws.WithKeySet(idp.KeySet), jws.WithValidateKey(true)) + _, err := jws.Verify([]byte(idToken), jws.WithKeySet(client.Provider.KeySet), jws.WithValidateKey(true)) if err != nil { return "", fmt.Errorf("failed to verify ID token: %v", err) } @@ -126,7 +130,7 @@ func NewJwtBearerFlow(eps JwtBearerFlowEndpoints, params JwtBearerFlowParams) (s // TODO: add trusted issuer to cache if successful // 4. create a new JWT based on the claims from the identity provider and sign - parsedIdToken, err := jwt.ParseString(idToken, jwt.WithKeySet(idp.KeySet)) + parsedIdToken, err := jwt.ParseString(idToken, jwt.WithKeySet(client.Provider.KeySet)) if err != nil { return "", fmt.Errorf("failed to parse ID token: %v", err) } @@ -242,7 +246,7 @@ func ForwardToken(eps JwtBearerFlowEndpoints, params JwtBearerFlowParams) error var ( client = params.Client idToken = params.IdToken - idp = params.IdentityProvider + // idp = params.IdentityProvider verbose = params.Verbose ) @@ -250,7 +254,7 @@ func ForwardToken(eps JwtBearerFlowEndpoints, params JwtBearerFlowParams) error if verbose { fmt.Printf("Fetching JWKS from authentication server for verification...\n") } - err := idp.FetchJwks() + err := client.Provider.FetchJwks() if err != nil { return fmt.Errorf("failed to fetch JWK: %v", err) } else { @@ -260,7 +264,7 @@ func ForwardToken(eps JwtBearerFlowEndpoints, params JwtBearerFlowParams) error } ti := &oauth.TrustedIssuer{ - Issuer: idp.Issuer, + Issuer: client.Provider.Issuer, Subject: "1", ExpiresAt: time.Now().Add(time.Second * 3600), } diff --git a/internal/login.go b/internal/login.go index e9c6374..d542974 100644 --- a/internal/login.go +++ b/internal/login.go @@ -12,19 +12,11 @@ import ( "time" ) -func Login(config *Config, client *oauth.Client, provider *oidc.IdentityProvider) error { +func Login(config *Config) error { if config == nil { return fmt.Errorf("invalid config") } - if client == nil { - return fmt.Errorf("invalid client") - } - - if provider == nil { - return fmt.Errorf("invalid identity provider") - } - // make cache if it's not where expect _, err := cache.CreateIdentityProvidersIfNotExists(config.Options.CachePath) if err != nil { @@ -39,18 +31,12 @@ func Login(config *Config, client *oauth.Client, provider *oidc.IdentityProvider } // print the authorization URL for sharing - var authorizationUrl = client.BuildAuthorizationUrl(provider.Endpoints.Authorization, state) s := NewServerWithConfig(config) - fmt.Printf("Login with external identity provider:\n\n %s/login\n %s\n\n", - s.GetListenAddr(), authorizationUrl, - ) + s.State = state - var button = MakeButton(authorizationUrl, "Login with "+client.Name) var authzClient = oauth.NewClient() authzClient.Scope = config.Authorization.Token.Scope - // authorize oauth client and listen for callback from provider - fmt.Printf("Waiting for authorization code redirect @%s/oidc/callback...\n", s.GetListenAddr()) params := server.ServerParams{ Verbose: config.Options.Verbose, AuthProvider: &oidc.IdentityProvider{ @@ -66,8 +52,7 @@ func Login(config *Config, client *oauth.Client, provider *oidc.IdentityProvider Register: config.Authorization.Endpoints.Register, }, JwtBearerParams: flows.JwtBearerFlowParams{ - Client: authzClient, - IdentityProvider: provider, + Client: authzClient, TrustedIssuer: &oauth.TrustedIssuer{ AllowAnySubject: false, Issuer: s.Addr, @@ -87,7 +72,7 @@ func Login(config *Config, client *oauth.Client, provider *oidc.IdentityProvider Client: authzClient, }, } - err = s.StartLogin(button, provider, client, params) + err = s.StartLogin(config.Authentication.Clients, params) if errors.Is(err, http.ErrServerClosed) { fmt.Printf("\n=========================================\nServer closed.\n=========================================\n\n") } else if err != nil { @@ -96,7 +81,7 @@ func Login(config *Config, client *oauth.Client, provider *oidc.IdentityProvider } else if config.Options.FlowType == "client_credentials" { params := flows.ClientCredentialsFlowParams{ - Client: client, + Client: nil, // # FIXME: need to do something about this being nil I think } _, err := NewClientCredentialsFlowWithConfig(config, params) if err != nil { @@ -108,13 +93,3 @@ func Login(config *Config, client *oauth.Client, provider *oidc.IdentityProvider return nil } - -func MakeButton(url string, text string) string { - // check if we have http:// a - html := " " + text + "" -} diff --git a/internal/new.go b/internal/new.go index 2b2bad8..55d9cda 100644 --- a/internal/new.go +++ b/internal/new.go @@ -29,7 +29,7 @@ func NewClientWithConfig(config *Config) *oauth.Client { Id: clients[0].Id, Secret: clients[0].Secret, Name: clients[0].Name, - Issuer: clients[0].Issuer, + Provider: clients[0].Provider, Scope: clients[0].Scope, RedirectUris: clients[0].RedirectUris, } @@ -53,7 +53,7 @@ func NewClientWithConfigByName(config *Config, name string) *oauth.Client { func NewClientWithConfigByProvider(config *Config, issuer string) *oauth.Client { index := slices.IndexFunc(config.Authentication.Clients, func(c oauth.Client) bool { - return c.Issuer == issuer + return c.Provider.Issuer == issuer }) if index >= 0 { diff --git a/internal/oauth/authenticate.go b/internal/oauth/authenticate.go index 5724526..b579e8e 100644 --- a/internal/oauth/authenticate.go +++ b/internal/oauth/authenticate.go @@ -16,12 +16,15 @@ func (client *Client) IsFlowInitiated() bool { return client.FlowId != "" } -func (client *Client) BuildAuthorizationUrl(issuer string, state string) string { - return issuer + "?" + "client_id=" + client.Id + +func (client *Client) BuildAuthorizationUrl(state string) string { + url := client.Provider.Endpoints.Authorization + "?client_id=" + client.Id + "&redirect_uri=" + url.QueryEscape(strings.Join(client.RedirectUris, ",")) + "&response_type=code" + // this has to be set to "code" - "&state=" + state + "&scope=" + strings.Join(client.Scope, "+") + if state != "" { + url += "&state=" + state + } + return url } func (client *Client) InitiateLoginFlow(loginUrl string) error { @@ -90,7 +93,7 @@ func (client *Client) FetchCSRFToken(flowUrl string) error { return fmt.Errorf("failed to extract CSRF token: not found") } -func (client *Client) FetchTokenFromAuthenticationServer(code string, remoteUrl string, state string) ([]byte, error) { +func (client *Client) FetchTokenFromAuthenticationServer(code string, state string) ([]byte, error) { body := url.Values{ "grant_type": {"authorization_code"}, "client_id": {client.Id}, @@ -104,7 +107,7 @@ func (client *Client) FetchTokenFromAuthenticationServer(code string, remoteUrl if state != "" { body["state"] = []string{state} } - res, err := http.PostForm(remoteUrl, body) + res, err := http.PostForm(client.Provider.Endpoints.Token, body) if err != nil { return nil, fmt.Errorf("failed to get ID token: %s", err) } diff --git a/internal/oauth/client.go b/internal/oauth/client.go index 2040117..78ab9cc 100644 --- a/internal/oauth/client.go +++ b/internal/oauth/client.go @@ -1,6 +1,7 @@ package oauth import ( + "davidallendj/opaal/internal/oidc" "encoding/json" "fmt" "net/http" @@ -24,15 +25,15 @@ const ( type Client struct { http.Client - Id string `db:"id" yaml:"id"` - Secret string `db:"secret" yaml:"secret"` - Name string `db:"name" yaml:"name"` - Description string `db:"description" yaml:"description"` - Issuer string `db:"issuer" yaml:"issuer"` - RegistrationAccessToken string `db:"registration_access_token" yaml:"registration-access-token"` - RedirectUris []string `db:"redirect_uris" yaml:"redirect-uris"` - Scope []string `db:"scope" yaml:"scope"` - Audience []string `db:"audience" yaml:"audience"` + Id string `db:"id" yaml:"id"` + Secret string `db:"secret" yaml:"secret"` + Name string `db:"name" yaml:"name"` + Description string `db:"description" yaml:"description"` + Provider oidc.IdentityProvider `db:"issuer" yaml:"provider"` + RegistrationAccessToken string `db:"registration_access_token" yaml:"registration-access-token"` + RedirectUris []string `db:"redirect_uris" yaml:"redirect-uris"` + Scope []string `db:"scope" yaml:"scope"` + Audience []string `db:"audience" yaml:"audience"` FlowId string CsrfToken string } diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index af813ab..2ec1f9c 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -111,26 +111,11 @@ func (p *IdentityProvider) LoadServerConfig(path string) error { } func (p *IdentityProvider) FetchServerConfig() error { - // make a request to a server's openid-configuration - req, err := http.NewRequest(http.MethodGet, p.Issuer+"/.well-known/openid-configuration", bytes.NewBuffer([]byte{})) + tmp, err := FetchServerConfig(p.Issuer) if err != nil { - return fmt.Errorf("failed to create a new request: %v", err) - } - - client := &http.Client{} // temp client to get info and not used in flow - res, err := client.Do(req) - if err != nil { - return fmt.Errorf("failed to do request: %v", err) - } - - body, err := io.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("failed to read response body: %v", err) - } - err = p.ParseServerConfig(body) - if err != nil { - return fmt.Errorf("failed to parse server config: %v", err) + return err } + p = tmp return nil } @@ -147,10 +132,15 @@ func FetchServerConfig(issuer string) (*IdentityProvider, error) { return nil, fmt.Errorf("failed to do request: %v", err) } + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("HTTP status code: %d", res.StatusCode) + } + body, err := io.ReadAll(res.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %v", err) } + var p IdentityProvider err = p.ParseServerConfig(body) if err != nil { diff --git a/internal/server/server.go b/internal/server/server.go index 84c6aa4..3409630 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -58,10 +58,12 @@ func (s *Server) GetListenAddr() string { return fmt.Sprintf("%s:%d", s.Host, s.Port) } -func (s *Server) StartLogin(buttons string, provider *oidc.IdentityProvider, client *oauth.Client, params ServerParams) error { +func (s *Server) StartLogin(clients []oauth.Client, params ServerParams) error { var ( - target = "" - callback = "" + target string + callback string + client *oauth.Client + sso string ) // check if callback is set @@ -69,6 +71,29 @@ func (s *Server) StartLogin(buttons string, provider *oidc.IdentityProvider, cli callback = "/oidc/callback" } + // make the login page SSO buttons and authorization URLs to write to stdout + buttons := "" + fmt.Printf("Login with external identity providers: \n") + for i, client := range clients { + // fetch provider configuration before adding button + p, err := oidc.FetchServerConfig(client.Provider.Issuer) + if err != nil { + fmt.Printf("failed to fetch server config: %v\n", err) + continue + } + + // if we're able to get the config, go ahead and try to fetch jwks too + if err = p.FetchJwks(); err != nil { + fmt.Printf("failed to fetch JWKS: %v\n", err) + continue + } + + clients[i].Provider = *p + buttons += makeButton(fmt.Sprintf("/login?sso=%s", client.Id), client.Name) + url := client.BuildAuthorizationUrl(s.State) + fmt.Printf("\t%s\n", url) + } + var code string var accessToken string r := chi.NewRouter() @@ -93,21 +118,32 @@ func (s *Server) StartLogin(buttons string, provider *oidc.IdentityProvider, cli // add target if query exists if r != nil { target = r.URL.Query().Get("target") - sso := r.URL.Query().Get("sso") + sso = r.URL.Query().Get("sso") + + // TODO: get client from list and build the authorization URL string + index := slices.IndexFunc(clients, func(c oauth.Client) bool { + return c.Id == sso + }) + + // TODO: redirect the user to authorization URL and return from func + foundClient := index >= 0 + if foundClient { + client = &clients[index] + + url := client.BuildAuthorizationUrl(s.State) + fmt.Printf("Redirect URL: %s\n", url) + http.Redirect(w, r, url, http.StatusFound) + return + } } + // show login page with notice to redirect template, err := gonja.FromFile("pages/index.html") if err != nil { panic(err) } - // form, err := os.ReadFile("pages/login.html") - // if err != nil { - // fmt.Printf("failed to load login form: %v", err) - // } - data := exec.NewContext(map[string]interface{}{ - // "loginForm": string(form), "loginButtons": buttons, }) @@ -158,7 +194,7 @@ func (s *Server) StartLogin(buttons string, provider *oidc.IdentityProvider, cli // use refresh token provided to do a refresh token grant refreshToken := r.URL.Query().Get("refresh-token") if refreshToken != "" { - _, err := params.JwtBearerParams.Client.PerformRefreshTokenGrant(provider.Endpoints.Token, refreshToken) + _, err := params.JwtBearerParams.Client.PerformRefreshTokenGrant(client.Provider.Endpoints.Token, refreshToken) if err != nil { fmt.Printf("failed to perform refresh token grant: %v\n", err) http.Redirect(w, r, "/error", http.StatusInternalServerError) @@ -196,10 +232,15 @@ func (s *Server) StartLogin(buttons string, provider *oidc.IdentityProvider, cli fmt.Printf("Authorization code: %v\n", code) } + // make sure we have the correct client to use + if client == nil { + fmt.Printf("failed to find valid client") + return + } + // use code from response and exchange for bearer token (with ID token) bearerToken, err := client.FetchTokenFromAuthenticationServer( code, - provider.Endpoints.Token, s.State, ) if err != nil { @@ -229,6 +270,7 @@ func (s *Server) StartLogin(buttons string, provider *oidc.IdentityProvider, cli // complete JWT bearer flow to receive access token from authorization server // fmt.Printf("bearer: %v\n", string(bearerToken)) params.JwtBearerParams.IdToken = data["id_token"].(string) + params.JwtBearerParams.Client = client accessToken, err = flows.NewJwtBearerFlow(params.JwtBearerEndpoints, params.JwtBearerParams) if err != nil { fmt.Printf("failed to complete JWT bearer flow: %v\n", err) @@ -407,10 +449,12 @@ func (s *Server) StartIdentityProvider() error { // example username and password so do simplified authorization code flow if username == "ochami" && password == "ochami" { client := oauth.Client{ - Id: "ochami", - Secret: "ochami", - Name: "ochami", - Issuer: "http://127.0.0.1:3333", + Id: "ochami", + Secret: "ochami", + Name: "ochami", + Provider: oidc.IdentityProvider{ + Issuer: "http://127.0.0.1:3333", + }, RedirectUris: []string{fmt.Sprintf("http://%s:%d%s", s.Host, s.Port, callback)}, } @@ -542,3 +586,13 @@ func (s *Server) StartIdentityProvider() error { s.Handler = r return s.ListenAndServe() } + +func makeButton(url string, text string) string { + // check if we have http:// a + html := "", text) + return html + // return " " + text + "" +}