From eb43261b911c83e8a42e6dde0039b006e481bf1f Mon Sep 17 00:00:00 2001 From: "David J. Allen" Date: Sat, 24 Feb 2024 21:05:05 -0700 Subject: [PATCH] Refactoring and more implementation --- internal/config.go | 5 +- internal/login.go | 184 ++++++++++++++++++++++++ internal/oidc/oidc.go | 2 +- internal/opaal.go | 321 ++++++++++++++---------------------------- internal/util/util.go | 10 +- 5 files changed, 293 insertions(+), 229 deletions(-) create mode 100644 internal/login.go diff --git a/internal/config.go b/internal/config.go index d7c0756..e507c2f 100644 --- a/internal/config.go +++ b/internal/config.go @@ -1,7 +1,6 @@ package opaal import ( - "davidallendj/opaal/internal/oauth" "davidallendj/opaal/internal/oidc" "davidallendj/opaal/internal/util" "log" @@ -14,7 +13,7 @@ import ( type Config struct { Version string `yaml:"version"` Server Server `yaml:"server"` - Client oauth.Client `yaml:"client"` + Client Client `yaml:"client"` IdentityProvider oidc.IdentityProvider `yaml:"oidc"` State string `yaml:"state"` ResponseType string `yaml:"response-type"` @@ -32,7 +31,7 @@ func NewConfig() Config { Host: "127.0.0.1", Port: 3333, }, - Client: oauth.Client{ + Client: Client{ Id: "", Secret: "", RedirectUris: []string{""}, diff --git a/internal/login.go b/internal/login.go new file mode 100644 index 0000000..c84454c --- /dev/null +++ b/internal/login.go @@ -0,0 +1,184 @@ +package opaal + +import ( + "davidallendj/opaal/internal/oidc" + "davidallendj/opaal/internal/util" + "encoding/json" + "errors" + "fmt" + "net/http" + "time" +) + +func Login(config *Config) error { + if config == nil { + return fmt.Errorf("config is not valid") + } + + // initialize client that will be used throughout login flow + server := NewServerWithConfig(config) + client := NewClientWithConfig(config) + + // try and fetch server configuration if provided URL + idp := oidc.NewIdentityProvider() + if config.ActionUrls.ServerConfig != "" { + fmt.Printf("Fetching server configuration: %s\n", config.ActionUrls.ServerConfig) + err := idp.FetchServerConfig(config.ActionUrls.ServerConfig) + if err != nil { + return fmt.Errorf("failed to fetch server config: %v", err) + } + } else { + // otherwise, use what's provided in config file + idp.Issuer = config.IdentityProvider.Issuer + idp.Endpoints = config.IdentityProvider.Endpoints + idp.Supported = config.IdentityProvider.Supported + } + + // check if all appropriate parameters are set in config + if !hasRequiredParams(config) { + return fmt.Errorf("client ID must be set") + } + + // build the authorization URL to redirect user for social sign-in + var authorizationUrl = client.BuildAuthorizationUrl( + idp.Endpoints.Authorization, + config.State, + config.ResponseType, + config.Scope, + ) + + // print the authorization URL for sharing + fmt.Printf("Login with identity provider:\n\n %s/login\n %s\n\n", + server.GetListenAddr(), authorizationUrl, + ) + + // automatically open browser to initiate login flow (only useful for testing) + if config.OpenBrowser { + util.OpenUrl(authorizationUrl) + } + + // authorize oauth client and listen for callback from provider + fmt.Printf("Waiting for authorization code redirect @%s/oidc/callback...\n", server.GetListenAddr()) + code, err := server.WaitForAuthorizationCode(authorizationUrl) + if errors.Is(err, http.ErrServerClosed) { + fmt.Printf("Server closed.\n") + } else if err != nil { + return fmt.Errorf("failed to start server: %s", err) + } + + if client == nil { + fmt.Printf("client did not initialize\n") + } + + // use code from response and exchange for bearer token (with ID token) + tokenString, err := client.FetchTokenFromAuthenticationServer( + code, + idp.Endpoints.Token, + config.State, + ) + if err != nil { + return fmt.Errorf("failed to fetch token from issuer: %v", err) + } + + // unmarshal data to get id_token and access_token + var data map[string]any + err = json.Unmarshal([]byte(tokenString), &data) + if err != nil { + return fmt.Errorf("failed to unmarshal token: %v", err) + } + + // extract ID token from bearer as JSON string for easy consumption + idToken := data["id_token"].(string) + idJwtSegments, err := util.DecodeJwt(idToken) + if err != nil { + fmt.Printf("failed to parse ID token: %v\n", err) + } else { + if config.DecodeIdToken { + if err != nil { + fmt.Printf("failed to decode JWT: %v\n", err) + } else { + fmt.Printf("id_token.header: %s\nid_token.payload: %s\n", string(idJwtSegments[0]), string(idJwtSegments[1])) + } + } + } + + // extract the access token to get the scopes + // accessToken := data["access_token"].(string) + // accessJwtSegments, err := util.DecodeJwt(accessToken) + // if err != nil || len(accessJwtSegments) <= { + // fmt.Printf("failed to parse access token: %v\n", err) + // } else { + // if config.DecodeIdToken { + // if err != nil { + // fmt.Printf("failed to decode JWT: %v\n", err) + // } else { + // fmt.Printf("access_token.header: %s\naccess_token.payload: %s\n", string(accessJwtSegments[0]), string(accessJwtSegments[1])) + // } + // } + // } + + // create a new identity with identity and session manager if url is provided + if config.ActionUrls.Identities != "" { + fmt.Printf("Attempting to create a new identity...\n") + _, err := client.CreateIdentity(config.ActionUrls.Identities, idToken) + if err != nil { + return fmt.Errorf("failed to create new identity: %v", err) + } + _, err = client.FetchIdentities(config.ActionUrls.Identities) + if err != nil { + return fmt.Errorf("failed to fetch identities: %v", err) + } + } + + // extract the subject from ID token claims + var subject string + var idJsonPayload map[string]any + var idJwtPayload []byte = idJwtSegments[1] + if idJwtPayload != nil { + err := json.Unmarshal(idJwtPayload, &idJsonPayload) + if err != nil { + return fmt.Errorf("failed to unmarshal JWT: %v", err) + } + subject = idJsonPayload["sub"].(string) + } else { + return fmt.Errorf("failed to extract subject from ID token claims") + } + + // extract the scope from access token claims + // var scope []string + // var accessJsonPayload map[string]any + // var accessJwtPayload []byte = accessJwtSegments[1] + // if accessJsonPayload != nil { + // err := json.Unmarshal(accessJwtPayload, &accessJsonPayload) + // if err != nil { + // return fmt.Errorf("failed to unmarshal JWT: %v", err) + // } + // scope = idJsonPayload["scope"].([]string) + // } + + // fetch JWKS and add issuer to authentication server to submit ID token + fmt.Printf("Fetching JWKS from authentication server for verification...\n") + err = idp.FetchJwk(config.ActionUrls.JwksUri) + if err != nil { + fmt.Printf("failed to fetch JWK: %v\n", err) + } else { + fmt.Printf("Attempting to add issuer to authorization server...\n") + _, err = client.AddTrustedIssuer(config.ActionUrls.TrustedIssuers, idp, subject, time.Duration(1000), config.Scope) + if err != nil { + return fmt.Errorf("failed to add trusted issuer: %v", err) + } + } + + // use ID token/user info to fetch access token from authentication server + if config.ActionUrls.AccessToken != "" { + fmt.Printf("Fetching access token from authorization server...\n") + accessToken, err := client.FetchTokenFromAuthorizationServer(config.ActionUrls.AccessToken, idToken, config.Scope) + if err != nil { + return fmt.Errorf("failed to fetch access token: %v", err) + } + fmt.Printf("%s\n", accessToken) + } + + fmt.Printf("Success!") + return nil +} diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index f1c0a80..e17a120 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -116,7 +116,7 @@ func (p *IdentityProvider) FetchServerConfig(url string) error { return fmt.Errorf("failed to create a new request: %v", err) } - client := &http.Client{} + client := &http.Client{} // temp client to get info and not used in flow res, err := client.Do(req) if err != nil { return fmt.Errorf("failed to do request: %v", err) diff --git a/internal/opaal.go b/internal/opaal.go index e3dd05e..ec2c1da 100644 --- a/internal/opaal.go +++ b/internal/opaal.go @@ -2,24 +2,34 @@ package opaal import ( "bytes" - "davidallendj/opaal/internal/oauth" "davidallendj/opaal/internal/oidc" "davidallendj/opaal/internal/util" "encoding/json" - "errors" "fmt" "io" "net/http" + "net/http/cookiejar" "net/url" + "os" "strings" "time" + + "golang.org/x/net/publicsuffix" ) type Server struct { + http.Server Host string `yaml:"host"` Port int `yaml:"port"` } +type Client struct { + http.Client + Id string `yaml:"id"` + Secret string `yaml:"secret"` + RedirectUris []string `yaml:"redirect-uris"` +} + type ActionUrls struct { Identities string `yaml:"identities"` TrustedIssuers string `yaml:"trusted-issuers"` @@ -28,173 +38,51 @@ type ActionUrls struct { JwksUri string `yaml:"jwks_uri"` } -func Login(config *Config) error { - if config == nil { - return fmt.Errorf("config is not valid") +func NewServerWithConfig(config *Config) *Server { + host := config.Server.Host + port := config.Server.Port + server := &Server{ + Host: host, + Port: port, } - // try and fetch server configuration if provided URL - idp := oidc.NewIdentityProvider() - if config.ActionUrls.ServerConfig != "" { - fmt.Printf("Fetching server configuration: %s\n", config.ActionUrls.ServerConfig) - err := idp.FetchServerConfig(config.ActionUrls.ServerConfig) - if err != nil { - return fmt.Errorf("failed to fetch server config: %v", err) - } - } else { - // otherwise, use what's provided in config file - idp.Issuer = config.IdentityProvider.Issuer - idp.Endpoints = config.IdentityProvider.Endpoints - idp.Supported = config.IdentityProvider.Supported - } - - // check if all appropriate parameters are set in config - if !hasRequiredParams(config) { - return fmt.Errorf("client ID must be set") - } - - // build the authorization URL to redirect user for social sign-in - var authorizationUrl = util.BuildAuthorizationUrl( - idp.Endpoints.Authorization, - config.Client.Id, - config.Client.RedirectUris, - config.State, - config.ResponseType, - config.Scope, - ) - - // print the authorization URL for sharing - serverAddr := fmt.Sprintf("%s:%d", config.Server.Host, config.Server.Port) - fmt.Printf("Login with identity provider:\n\n %s/login\n %s\n\n", - serverAddr, authorizationUrl, - ) - - // automatically open browser to initiate login flow (only useful for testing) - if config.OpenBrowser { - util.OpenUrl(authorizationUrl) - } - - // authorize oauth client and listen for callback from provider - fmt.Printf("Waiting for authorization code redirect @%s/oidc/callback...\n", serverAddr) - code, err := WaitForAuthorizationCode(serverAddr, authorizationUrl) - if errors.Is(err, http.ErrServerClosed) { - fmt.Printf("Server closed.\n") - } else if err != nil { - return fmt.Errorf("failed to start server: %s", err) - } - - // use code from response and exchange for bearer token (with ID token) - tokenString, err := FetchIssuerToken( - code, - idp.Endpoints.Token, - config.Client, - config.State, - ) - if err != nil { - return fmt.Errorf("failed to fetch token from issuer: %v", err) - } - - // unmarshal data to get id_token and access_token - var data map[string]any - err = json.Unmarshal([]byte(tokenString), &data) - if err != nil { - return fmt.Errorf("failed to unmarshal token: %v", err) - } - - // extract ID token from bearer as JSON string for easy consumption - idToken := data["id_token"].(string) - idJwtSegments, err := util.DecodeJwt(idToken) - if err != nil { - fmt.Printf("failed to parse ID token: %v\n", err) - } else { - if config.DecodeIdToken { - if err != nil { - fmt.Printf("failed to decode JWT: %v\n", err) - } else { - fmt.Printf("id_token.header: %s\nid_token.payload: %s\n", string(idJwtSegments[0]), string(idJwtSegments[1])) - } - } - } - - // extract the access token to get the scopes - // accessToken := data["access_token"].(string) - // accessJwtSegments, err := util.DecodeJwt(accessToken) - // if err != nil || len(accessJwtSegments) <= { - // fmt.Printf("failed to parse access token: %v\n", err) - // } else { - // if config.DecodeIdToken { - // if err != nil { - // fmt.Printf("failed to decode JWT: %v\n", err) - // } else { - // fmt.Printf("access_token.header: %s\naccess_token.payload: %s\n", string(accessJwtSegments[0]), string(accessJwtSegments[1])) - // } - // } - // } - - // create a new identity with identity and session manager if url is provided - if config.ActionUrls.Identities != "" { - CreateIdentity(config.ActionUrls.Identities, idToken) - FetchIdentities(config.ActionUrls.Identities) - } - - // extract the subject from ID token claims - var subject string - var idJsonPayload map[string]any - var idJwtPayload []byte = idJwtSegments[1] - if idJwtPayload != nil { - err := json.Unmarshal(idJwtPayload, &idJsonPayload) - if err != nil { - return fmt.Errorf("failed to unmarshal JWT: %v", err) - } - subject = idJsonPayload["sub"].(string) - } else { - return fmt.Errorf("failed to extract subject from ID token claims") - } - - // extract the scope from access token claims - // var scope []string - // var accessJsonPayload map[string]any - // var accessJwtPayload []byte = accessJwtSegments[1] - // if accessJsonPayload != nil { - // err := json.Unmarshal(accessJwtPayload, &accessJsonPayload) - // if err != nil { - // return fmt.Errorf("failed to unmarshal JWT: %v", err) - // } - // scope = idJsonPayload["scope"].([]string) - // } - - // fetch JWKS and add issuer to authentication server to submit ID token - fmt.Printf("Fetching JWKS for verification...\n") - err = idp.FetchJwk(config.ActionUrls.JwksUri) - if err != nil { - fmt.Printf("failed to fetch JWK: %v\n", err) - } else { - fmt.Printf("Attempting to add issuer to authorization server...\n") - err = AddTrustedIssuer(config.ActionUrls.TrustedIssuers, idp, subject, time.Duration(1000), config.Scope) - if err != nil { - return fmt.Errorf("failed to add trusted issuer: %v", err) - } - } - - // use ID token/user info to fetch access token from authentication server - if config.ActionUrls.AccessToken != "" { - fmt.Printf("Fetching access token from authorization server...") - accessToken, err := FetchAccessToken(config.ActionUrls.AccessToken, config.Client.Id, idToken, config.Scope) - if err != nil { - return fmt.Errorf("failed to fetch access token: %v", err) - } - fmt.Printf("%s\n", accessToken) - } - - fmt.Printf("Success!") - return nil + server.Addr = fmt.Sprintf("%s:%d", host, port) + return server } -func WaitForAuthorizationCode(serverAddr string, loginUrl string) (string, error) { +func NewClientWithConfig(config *Config) *Client { + jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) + return &Client{ + Id: config.Client.Id, + Secret: config.Client.Secret, + RedirectUris: config.Client.RedirectUris, + Client: http.Client{Jar: jar}, + } +} + +func (s *Server) SetListenAddr(host string, port int) { + s.Host = host + s.Port = port + s.Addr = s.GetListenAddr() +} + +func (s *Server) GetListenAddr() string { + return fmt.Sprintf("%s:%d", s.Host, s.Port) +} + +func (s *Server) WaitForAuthorizationCode(loginUrl string) (string, error) { var code string - s := &http.Server{Addr: serverAddr} + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/login", http.StatusSeeOther) + }) http.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) { - // redirect directly to identity provider with this endpoint - http.Redirect(w, r, loginUrl, http.StatusSeeOther) + // show login page with notice to redirect + loginPage, err := os.ReadFile("pages/index.html") + if err != nil { + fmt.Printf("failed to load login page: %v\n", err) + } + loginPage = []byte(strings.ReplaceAll(string(loginPage), "{{loginUrl}}", loginUrl)) + w.WriteHeader(http.StatusSeeOther) + w.Write(loginPage) }) http.HandleFunc("/oidc/callback", func(w http.ResponseWriter, r *http.Request) { // get the code from the OIDC provider @@ -202,13 +90,28 @@ func WaitForAuthorizationCode(serverAddr string, loginUrl string) (string, error code = r.URL.Query().Get("code") fmt.Printf("Authorization code: %v\n", code) } + http.Redirect(w, r, s.Addr+"/success", http.StatusSeeOther) s.Close() }) return code, s.ListenAndServe() } -func FetchIssuerToken(code string, remoteUrl string, client oauth.Client, state string) (string, error) { - var token string +func (s *Server) ShowSuccessPage() error { + http.HandleFunc("/success", func(w http.ResponseWriter, r *http.Request) { + + }) + return s.ListenAndServe() +} + +func (client *Client) BuildAuthorizationUrl(authEndpoint string, state string, responseType string, scope []string) string { + return authEndpoint + "?" + "client_id=" + client.Id + + "&redirect_uri=" + util.URLEscape(strings.Join(client.RedirectUris, ",")) + + "&response_type=" + responseType + + "&state=" + state + + "&scope=" + strings.Join(scope, "+") +} + +func (client *Client) FetchTokenFromAuthenticationServer(code string, remoteUrl string, state string) ([]byte, error) { data := url.Values{ "grant_type": {"authorization_code"}, "code": {code}, @@ -219,77 +122,68 @@ func FetchIssuerToken(code string, remoteUrl string, client oauth.Client, state } res, err := http.PostForm(remoteUrl, data) if err != nil { - return "", fmt.Errorf("failed to get ID token: %s", err) + return nil, fmt.Errorf("failed to get ID token: %s", err) } defer res.Body.Close() - b, err := io.ReadAll(res.Body) - if err != nil { - return "", fmt.Errorf("failed to read response body: %v", err) - } - token = string(b) - - fmt.Printf("%v\n", token) - return token, nil + return io.ReadAll(res.Body) } -func FetchAccessToken(remoteUrl string, clientId string, jwt string, scopes []string) (string, error) { +func (client *Client) FetchTokenFromAuthorizationServer(remoteUrl string, jwt string, scope []string) ([]byte, error) { // hydra endpoint: /oauth/token - var token string - data := url.Values{ - "grant_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"}, - "assertion": {jwt}, - } - res, err := http.PostForm(remoteUrl, data) + data := "grant_type=" + util.URLEscape("urn:ietf:params:oauth:grant-type:jwt-bearer") + + "&assertion=" + jwt + + "&scope=" + strings.Join(scope, "+") + fmt.Printf("encoded params: %v\n\n", data) + req, err := http.NewRequest("POST", remoteUrl, bytes.NewBuffer([]byte(data))) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") if err != nil { - return "", fmt.Errorf("failed to get token: %s", err) + return nil, fmt.Errorf("failed to make request: %s", err) + } + res, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to do request: %v", err) } defer res.Body.Close() - b, err := io.ReadAll(res.Body) - if err != nil { - return "", fmt.Errorf("failed to read response body: %v", err) - } - token = string(b) - - fmt.Printf("%v\n", token) - return token, nil + return io.ReadAll(res.Body) } -func AddTrustedIssuer(remoteUrl string, idp *oidc.IdentityProvider, subject string, duration time.Duration, scope []string) error { +func (client *Client) AddTrustedIssuer(remoteUrl string, idp *oidc.IdentityProvider, subject string, duration time.Duration, scope []string) ([]byte, error) { // hydra endpoint: /admin/trust/grants/jwt-bearer/issuers if idp == nil { - return fmt.Errorf("identity provided is nil") + return nil, fmt.Errorf("identity provided is nil") } jwkstr, err := json.Marshal(idp.Key) if err != nil { - return fmt.Errorf("failed to marshal JWK: %v", err) + return nil, fmt.Errorf("failed to marshal JWK: %v", err) } data := []byte(fmt.Sprintf(`{ - "allow_any_subject": false, + "allow_any_subject": true, "issuer": "%s", "subject": "%s" "expires_at": "%v" "jwk": %v, - "scope": [ j%s ], + "scope": [ %s ], }`, idp.Issuer, subject, time.Now().Add(duration), string(jwkstr), strings.Join(scope, ","))) req, err := http.NewRequest("POST", remoteUrl, bytes.NewBuffer(data)) + // req.Header.Add("X-CSRF-Token", client.CsrfToken.Value) if err != nil { - return fmt.Errorf("failed to create a new request: %v", err) + return nil, fmt.Errorf("failed to make request: %v", err) } req.Header.Add("Content-Type", "application/json") // req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", idToken)) - client := &http.Client{} res, err := client.Do(req) if err != nil { - return fmt.Errorf("failed to do request: %v", err) + return nil, fmt.Errorf("failed to do request: %v", err) } - fmt.Printf("%d\n", res.StatusCode) - return nil + defer res.Body.Close() + + return io.ReadAll(res.Body) } -func CreateIdentity(remoteUrl string, idToken string) error { +func (client *Client) CreateIdentity(remoteUrl string, idToken string) ([]byte, error) { // kratos endpoint: /admin/identities data := []byte(`{ "schema_id": "preset://email", @@ -300,36 +194,31 @@ func CreateIdentity(remoteUrl string, idToken string) error { req, err := http.NewRequest("POST", remoteUrl, bytes.NewBuffer(data)) if err != nil { - return fmt.Errorf("failed to create a new request: %v", err) + return nil, fmt.Errorf("failed to create a new request: %v", err) } req.Header.Add("Content-Type", "application/json") req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", idToken)) - client := &http.Client{} + // req.Header.Add("X-CSRF-Token", client.CsrfToken.Value) res, err := client.Do(req) if err != nil { - return fmt.Errorf("failed to do request: %v", err) + return nil, fmt.Errorf("failed to do request: %v", err) } - fmt.Printf("%d\n", res.StatusCode) - return nil + + return io.ReadAll(res.Body) } -func FetchIdentities(remoteUrl string) error { +func (client *Client) FetchIdentities(remoteUrl string) ([]byte, error) { req, err := http.NewRequest("GET", remoteUrl, bytes.NewBuffer([]byte{})) if err != nil { - return fmt.Errorf("failed to create a new request: %v", err) + return nil, fmt.Errorf("failed to create a new request: %v", err) } - client := &http.Client{} res, err := client.Do(req) if err != nil { - return fmt.Errorf("failed to do request: %v", err) + return nil, fmt.Errorf("failed to do request: %v", err) } - fmt.Printf("%v\n", res) - return nil -} -func RedirectSuccess() { - // show a success page with the user's access token + return io.ReadAll(res.Body) } func hasRequiredParams(config *Config) bool { diff --git a/internal/util/util.go b/internal/util/util.go index 470291b..3a28bc5 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -36,15 +36,7 @@ func RandomString(n int) string { return string(b) } -func BuildAuthorizationUrl(authEndpoint string, clientId string, redirectUri []string, state string, responseType string, scope []string) string { - return authEndpoint + "?" + "client_id=" + clientId + - "&redirect_uri=" + EncodeURL(strings.Join(redirectUri, ",")) + - "&response_type=" + responseType + - "&state=" + state + - "&scope=" + strings.Join(scope, "+") -} - -func EncodeURL(s string) string { +func URLEscape(s string) string { return url.QueryEscape(s) }