Added flow initialization and completed implementation

This commit is contained in:
David Allen 2024-02-26 09:20:07 -07:00
parent a3f0caf4ff
commit 1de4d3a5d5
No known key found for this signature in database
GPG key ID: 1D2A29322FBB6FCB
6 changed files with 262 additions and 41 deletions

View file

@ -21,6 +21,8 @@ type Client struct {
Id string `yaml:"id"` Id string `yaml:"id"`
Secret string `yaml:"secret"` Secret string `yaml:"secret"`
RedirectUris []string `yaml:"redirect-uris"` RedirectUris []string `yaml:"redirect-uris"`
FlowId string
CsrfToken string
} }
func NewClientWithConfig(config *Config) *Client { func NewClientWithConfig(config *Config) *Client {
@ -33,6 +35,10 @@ func NewClientWithConfig(config *Config) *Client {
} }
} }
func (client *Client) IsFlowInitiated() bool {
return client.FlowId != ""
}
func (client *Client) BuildAuthorizationUrl(authEndpoint string, state string, responseType string, scope []string) string { func (client *Client) BuildAuthorizationUrl(authEndpoint string, state string, responseType string, scope []string) string {
return authEndpoint + "?" + "client_id=" + client.Id + return authEndpoint + "?" + "client_id=" + client.Id +
"&redirect_uri=" + util.URLEscape(strings.Join(client.RedirectUris, ",")) + "&redirect_uri=" + util.URLEscape(strings.Join(client.RedirectUris, ",")) +
@ -41,6 +47,80 @@ func (client *Client) BuildAuthorizationUrl(authEndpoint string, state string, r
"&scope=" + strings.Join(scope, "+") "&scope=" + strings.Join(scope, "+")
} }
func (client *Client) InitiateLoginFlow(loginUrl string) error {
// kratos: GET /self-service/login/api
req, err := http.NewRequest("GET", loginUrl, bytes.NewBuffer([]byte{}))
if err != nil {
return fmt.Errorf("failed to make request: %v", err)
}
res, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to do request: %v", err)
}
defer res.Body.Close()
// get the flow ID from response
body, err := io.ReadAll(res.Body)
var flowData map[string]any
err = json.Unmarshal(body, &flowData)
if err != nil {
return fmt.Errorf("failed to unmarshal flow data: %v\n%v", err, string(body))
} else {
client.FlowId = flowData["id"].(string)
}
return nil
}
func (client *Client) FetchFlowData(flowUrl string) (JsonObject, error) {
//kratos: GET /self-service/login/flows?id={flowId}
// replace {id} in string with actual value
flowUrl = strings.ReplaceAll(flowUrl, "{id}", client.FlowId)
req, err := http.NewRequest("GET", flowUrl, nil)
if err != nil {
return nil, fmt.Errorf("failed to make request: %v", err)
}
res, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to do request: %v", err)
}
defer res.Body.Close()
// get the flow data from response
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %v", err)
}
var flowData JsonObject
err = json.Unmarshal(body, &flowData)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal flow data: %v", err)
}
return flowData, nil
}
func (client *Client) FetchCSRFToken(flowUrl string) error {
data, err := client.FetchFlowData(flowUrl)
if err != nil {
return fmt.Errorf("failed to fetch flow data: %v", err)
}
// iterate through nodes and extract the CSRF token attribute from the flow data
ui := data["ui"].(map[string]any)
nodes := ui["nodes"].([]any)
for _, node := range nodes {
attrs := node.(map[string]any)["attributes"].(map[string]any)
name := attrs["name"].(string)
if name == "csrf_token" {
client.CsrfToken = attrs["value"].(string)
return nil
}
}
return fmt.Errorf("failed to extract CSRF token: not found")
}
func (client *Client) FetchTokenFromAuthenticationServer(code string, remoteUrl string, state string) ([]byte, error) { func (client *Client) FetchTokenFromAuthenticationServer(code string, remoteUrl string, state string) ([]byte, error) {
data := url.Values{ data := url.Values{
"grant_type": {"authorization_code"}, "grant_type": {"authorization_code"},
@ -56,14 +136,19 @@ func (client *Client) FetchTokenFromAuthenticationServer(code string, remoteUrl
} }
defer res.Body.Close() defer res.Body.Close()
domain, _ := url.Parse("http://127.0.0.1")
client.Jar.SetCookies(domain, res.Cookies())
return io.ReadAll(res.Body) return io.ReadAll(res.Body)
} }
func (client *Client) FetchTokenFromAuthorizationServer(remoteUrl string, jwt string, scope []string) ([]byte, error) { func (client *Client) FetchTokenFromAuthorizationServer(remoteUrl string, jwt string, scope []string) ([]byte, error) {
// hydra endpoint: /oauth/token // hydra endpoint: /oauth/token
data := "grant_type=" + util.URLEscape("urn:ietf:params:oauth:grant-type:jwt-bearer") + data := "grant_type=" + util.URLEscape("urn:ietf:params:oauth:grant-type:jwt-bearer") +
"&assertion=" + jwt + "&client_id=" + client.Id +
"&scope=" + strings.Join(scope, "+") "&client_secret=" + client.Secret +
"&scope=" + strings.Join(scope, "+") +
"&assertion=" + jwt
fmt.Printf("encoded params: %v\n\n", data) fmt.Printf("encoded params: %v\n\n", data)
req, err := http.NewRequest("POST", remoteUrl, bytes.NewBuffer([]byte(data))) req, err := http.NewRequest("POST", remoteUrl, bytes.NewBuffer([]byte(data)))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded") req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
@ -76,11 +161,14 @@ func (client *Client) FetchTokenFromAuthorizationServer(remoteUrl string, jwt st
} }
defer res.Body.Close() defer res.Body.Close()
// set flow ID back to empty string to indicate a completed flow
client.FlowId = ""
return io.ReadAll(res.Body) return io.ReadAll(res.Body)
} }
func (client *Client) AddTrustedIssuer(remoteUrl string, idp *oidc.IdentityProvider, subject string, duration time.Duration, scope []string) ([]byte, error) { func (client *Client) AddTrustedIssuer(remoteUrl string, idp *oidc.IdentityProvider, subject string, duration time.Duration, scope []string) ([]byte, error) {
// hydra endpoint: /admin/trust/grants/jwt-bearer/issuers // hydra endpoint: POST /admin/trust/grants/jwt-bearer/issuers
if idp == nil { if idp == nil {
return nil, fmt.Errorf("identity provided is nil") return nil, fmt.Errorf("identity provided is nil")
} }
@ -88,14 +176,20 @@ func (client *Client) AddTrustedIssuer(remoteUrl string, idp *oidc.IdentityProvi
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to marshal JWK: %v", err) return nil, fmt.Errorf("failed to marshal JWK: %v", err)
} }
quotedScopes := make([]string, len(scope))
for i, s := range scope {
quotedScopes[i] = fmt.Sprintf("\"%s\"", s)
}
// NOTE: Can also include "jwks_uri" instead
data := []byte(fmt.Sprintf(`{ data := []byte(fmt.Sprintf(`{
"allow_any_subject": true, "allow_any_subject": false,
"issuer": "%s", "issuer": "%s",
"subject": "%s" "subject": "%s",
"expires_at": "%v" "expires_at": "%v",
"jwk": %v, "jwk": %v,
"scope": [ %s ], "scope": [ %s ]
}`, idp.Issuer, subject, time.Now().Add(duration), string(jwkstr), strings.Join(scope, ","))) }`, idp.Issuer, subject, time.Now().Add(duration).Format(time.RFC3339), string(jwkstr), strings.Join(quotedScopes, ",")))
fmt.Printf("%v\n", string(data))
req, err := http.NewRequest("POST", remoteUrl, bytes.NewBuffer(data)) req, err := http.NewRequest("POST", remoteUrl, bytes.NewBuffer(data))
// req.Header.Add("X-CSRF-Token", client.CsrfToken.Value) // req.Header.Add("X-CSRF-Token", client.CsrfToken.Value)
@ -113,6 +207,32 @@ func (client *Client) AddTrustedIssuer(remoteUrl string, idp *oidc.IdentityProvi
return io.ReadAll(res.Body) return io.ReadAll(res.Body)
} }
func (client *Client) RegisterOAuthClient(registerUrl string) ([]byte, error) {
// hydra endpoint: POST /clients
data := []byte(fmt.Sprintf(`{
"client_name": "%s",
"client_secret": "%s",
"token_endpoint_auth_method": "client_secret_post",
"scope": "openid email profile",
"grant_types": ["client_credentials", "urn:ietf:params:oauth:grant-type:jwt-bearer"],
"response_types": ["token"]
}`, client.Id, client.Secret))
req, err := http.NewRequest("POST", registerUrl, bytes.NewBuffer(data))
if err != nil {
return nil, fmt.Errorf("failed to make request: %v", err)
}
req.Header.Add("Content-Type", "application/json")
// req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", idToken))
res, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to do request: %v", err)
}
defer res.Body.Close()
return io.ReadAll(res.Body)
}
func (client *Client) CreateIdentity(remoteUrl string, idToken string) ([]byte, error) { func (client *Client) CreateIdentity(remoteUrl string, idToken string) ([]byte, error) {
// kratos endpoint: /admin/identities // kratos endpoint: /admin/identities
data := []byte(`{ data := []byte(`{
@ -150,3 +270,8 @@ func (client *Client) FetchIdentities(remoteUrl string) ([]byte, error) {
return io.ReadAll(res.Body) return io.ReadAll(res.Body)
} }
func (client *Client) ClearCookies() {
jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
client.Jar = jar
}

View file

@ -22,6 +22,7 @@ type Config struct {
OpenBrowser bool `yaml:"open-browser"` OpenBrowser bool `yaml:"open-browser"`
DecodeIdToken bool `yaml:"decode-id-token"` DecodeIdToken bool `yaml:"decode-id-token"`
DecodeAccessToken bool `yaml:"decode-access-token"` DecodeAccessToken bool `yaml:"decode-access-token"`
RunOnce bool `yaml:"run-once"`
} }
func NewConfig() Config { func NewConfig() Config {
@ -49,6 +50,7 @@ func NewConfig() Config {
OpenBrowser: false, OpenBrowser: false,
DecodeIdToken: false, DecodeIdToken: false,
DecodeAccessToken: false, DecodeAccessToken: false,
RunOnce: true,
} }
} }

View file

@ -19,6 +19,18 @@ func Login(config *Config) error {
server := NewServerWithConfig(config) server := NewServerWithConfig(config)
client := NewClientWithConfig(config) client := NewClientWithConfig(config)
// initiate the login flow and get a flow ID and CSRF token
{
err := client.InitiateLoginFlow(config.ActionUrls.Login)
if err != nil {
return fmt.Errorf("failed to initiate login flow: %v", err)
}
err = client.FetchCSRFToken(config.ActionUrls.LoginFlowId)
if err != nil {
return fmt.Errorf("failed to fetch CSRF token: %v", err)
}
}
// try and fetch server configuration if provided URL // try and fetch server configuration if provided URL
idp := oidc.NewIdentityProvider() idp := oidc.NewIdentityProvider()
if config.ActionUrls.ServerConfig != "" { if config.ActionUrls.ServerConfig != "" {
@ -35,7 +47,7 @@ func Login(config *Config) error {
} }
// check if all appropriate parameters are set in config // check if all appropriate parameters are set in config
if !hasRequiredParams(config) { if !HasRequiredParams(config) {
return fmt.Errorf("client ID must be set") return fmt.Errorf("client ID must be set")
} }
@ -70,6 +82,24 @@ func Login(config *Config) error {
fmt.Printf("client did not initialize\n") fmt.Printf("client did not initialize\n")
} }
// start up another serve in background to listen for success or failures
d := make(chan []byte)
quit := make(chan bool)
var access_token []byte
go server.Serve(d)
go func() {
select {
case <-d:
fmt.Printf("got access token")
quit <- true
case <-quit:
close(d)
close(quit)
return
default:
}
}()
// use code from response and exchange for bearer token (with ID token) // use code from response and exchange for bearer token (with ID token)
tokenString, err := client.FetchTokenFromAuthenticationServer( tokenString, err := client.FetchTokenFromAuthenticationServer(
code, code,
@ -93,6 +123,7 @@ func Login(config *Config) error {
if err != nil { if err != nil {
fmt.Printf("failed to parse ID token: %v\n", err) fmt.Printf("failed to parse ID token: %v\n", err)
} else { } else {
fmt.Printf("token: %v\n", idToken)
if config.DecodeIdToken { if config.DecodeIdToken {
if err != nil { if err != nil {
fmt.Printf("failed to decode JWT: %v\n", err) fmt.Printf("failed to decode JWT: %v\n", err)
@ -117,6 +148,18 @@ func Login(config *Config) error {
// } // }
// } // }
// extract the scope from access token claims
// var scope []string
// var accessJsonPayload map[string]any
// var accessJwtPayload []byte = accessJwtSegments[1]
// if accessJsonPayload != nil {
// err := json.Unmarshal(accessJwtPayload, &accessJsonPayload)
// if err != nil {
// return fmt.Errorf("failed to unmarshal JWT: %v", err)
// }
// scope = idJsonPayload["scope"].([]string)
// }
// create a new identity with identity and session manager if url is provided // create a new identity with identity and session manager if url is provided
if config.ActionUrls.Identities != "" { if config.ActionUrls.Identities != "" {
fmt.Printf("Attempting to create a new identity...\n") fmt.Printf("Attempting to create a new identity...\n")
@ -145,44 +188,47 @@ func Login(config *Config) error {
return fmt.Errorf("failed to extract subject from ID token claims") return fmt.Errorf("failed to extract subject from ID token claims")
} }
// extract the scope from access token claims
// var scope []string
// var accessJsonPayload map[string]any
// var accessJwtPayload []byte = accessJwtSegments[1]
// if accessJsonPayload != nil {
// err := json.Unmarshal(accessJwtPayload, &accessJsonPayload)
// if err != nil {
// return fmt.Errorf("failed to unmarshal JWT: %v", err)
// }
// scope = idJsonPayload["scope"].([]string)
// }
// fetch JWKS and add issuer to authentication server to submit ID token // fetch JWKS and add issuer to authentication server to submit ID token
fmt.Printf("Fetching JWKS from authentication server for verification...\n") fmt.Printf("Fetching JWKS from authentication server for verification...\n")
err = idp.FetchJwk(config.ActionUrls.JwksUri) err = idp.FetchJwk(config.ActionUrls.JwksUri)
if err != nil { if err != nil {
return fmt.Errorf("failed to fetch JWK: %v\n", err) return fmt.Errorf("failed to fetch JWK: %v", err)
} else { } else {
fmt.Printf("Attempting to add issuer to authorization server...\n") fmt.Printf("Attempting to add issuer to authorization server...\n")
res, err := client.AddTrustedIssuer(config.ActionUrls.TrustedIssuers, idp, subject, time.Duration(1000), config.Scope) res, err := client.AddTrustedIssuer(config.ActionUrls.TrustedIssuers, idp, subject, time.Duration(1000), config.Scope)
if err != nil { if err != nil {
return fmt.Errorf("failed to add trusted issuer: %v", err) return fmt.Errorf("failed to add trusted issuer: %v", err)
} }
if string(res) == "" { fmt.Printf("%v\n", string(res))
fmt.Printf("Added issuer to authorization server successfully.\n") }
}
// try and register a new client with authorization server
res, err := client.RegisterOAuthClient("http://127.0.0.1:4445/clients")
if err != nil {
return fmt.Errorf("failed to register client: %v", err)
}
fmt.Printf("%v\n", string(res))
// extract the client info from response
var clientData map[string]any
err = json.Unmarshal(res, &clientData)
if err != nil {
return fmt.Errorf("failed to unmarshal client data: %v", err)
} else {
client.Id = clientData["client_id"].(string)
client.Secret = clientData["client_secret"].(string)
} }
// use ID token/user info to fetch access token from authentication server // use ID token/user info to fetch access token from authentication server
if config.ActionUrls.AccessToken != "" { if config.ActionUrls.AccessToken != "" {
fmt.Printf("Fetching access token from authorization server...\n") fmt.Printf("Fetching access token from authorization server...\n")
accessToken, err := client.FetchTokenFromAuthorizationServer(config.ActionUrls.AccessToken, idToken, config.Scope) res, err := client.FetchTokenFromAuthorizationServer(config.ActionUrls.AccessToken, idToken, config.Scope)
if err != nil { if err != nil {
return fmt.Errorf("failed to fetch access token: %v", err) return fmt.Errorf("failed to fetch access token: %v", err)
} }
fmt.Printf("%s\n", accessToken) fmt.Printf("%s\n", res)
} }
fmt.Printf("Success!") d <- access_token
return nil return nil
} }

View file

@ -6,8 +6,10 @@ type ActionUrls struct {
AccessToken string `yaml:"access-token"` AccessToken string `yaml:"access-token"`
ServerConfig string `yaml:"server-config"` ServerConfig string `yaml:"server-config"`
JwksUri string `yaml:"jwks_uri"` JwksUri string `yaml:"jwks_uri"`
Login string `yaml:"login"`
LoginFlowId string `yaml:"login-flow-id"`
} }
func hasRequiredParams(config *Config) bool { func HasRequiredParams(config *Config) bool {
return config.Client.Id != "" && config.Client.Secret != "" return config.Client.Id != "" && config.Client.Secret != ""
} }

View file

@ -5,10 +5,13 @@ import (
"net/http" "net/http"
"os" "os"
"strings" "strings"
"github.com/go-chi/chi/middleware"
"github.com/go-chi/chi/v5"
) )
type Server struct { type Server struct {
http.Server *http.Server
Host string `yaml:"host"` Host string `yaml:"host"`
Port int `yaml:"port"` Port int `yaml:"port"`
} }
@ -17,16 +20,16 @@ func NewServerWithConfig(config *Config) *Server {
host := config.Server.Host host := config.Server.Host
port := config.Server.Port port := config.Server.Port
server := &Server{ server := &Server{
Server: &http.Server{
Addr: fmt.Sprintf("%s:%d", host, port),
},
Host: host, Host: host,
Port: port, Port: port,
} }
server.Addr = fmt.Sprintf("%s:%d", host, port)
return server return server
} }
func (s *Server) SetListenAddr(host string, port int) { func (s *Server) SetListenAddr(host string, port int) {
s.Host = host
s.Port = port
s.Addr = s.GetListenAddr() s.Addr = s.GetListenAddr()
} }
@ -36,34 +39,69 @@ func (s *Server) GetListenAddr() string {
func (s *Server) WaitForAuthorizationCode(loginUrl string) (string, error) { func (s *Server) WaitForAuthorizationCode(loginUrl string) (string, error) {
var code string var code string
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { r := chi.NewRouter()
r.Use(middleware.RedirectSlashes)
r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/login", http.StatusSeeOther) http.Redirect(w, r, "/login", http.StatusSeeOther)
}) })
http.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) { r.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) {
// show login page with notice to redirect // show login page with notice to redirect
loginPage, err := os.ReadFile("pages/index.html") loginPage, err := os.ReadFile("pages/index.html")
if err != nil { if err != nil {
fmt.Printf("failed to load login page: %v\n", err) fmt.Printf("failed to load login page: %v\n", err)
} }
loginPage = []byte(strings.ReplaceAll(string(loginPage), "{{loginUrl}}", loginUrl)) loginPage = []byte(strings.ReplaceAll(string(loginPage), "{{loginUrl}}", loginUrl))
w.WriteHeader(http.StatusSeeOther)
w.Write(loginPage) w.Write(loginPage)
}) })
http.HandleFunc("/oidc/callback", func(w http.ResponseWriter, r *http.Request) { r.HandleFunc("/oidc/callback", func(w http.ResponseWriter, r *http.Request) {
// get the code from the OIDC provider // get the code from the OIDC provider
if r != nil { if r != nil {
code = r.URL.Query().Get("code") code = r.URL.Query().Get("code")
fmt.Printf("Authorization code: %v\n", code) fmt.Printf("Authorization code: %v\n", code)
} }
http.Redirect(w, r, s.Addr+"/success", http.StatusSeeOther) http.Redirect(w, r, "/redirect", http.StatusSeeOther)
s.Close()
}) })
r.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) {
err := s.Close()
if err != nil {
fmt.Printf("failed to close server: %v\n", err)
}
})
s.Handler = r
return code, s.ListenAndServe() return code, s.ListenAndServe()
} }
func (s *Server) ShowSuccessPage() error { func (s *Server) Serve(data chan []byte) error {
http.HandleFunc("/success", func(w http.ResponseWriter, r *http.Request) { output, ok := <-data
if !ok {
return fmt.Errorf("failed to receive data")
}
fmt.Printf("Received data: %v\n", string(output))
// http.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) {
// })
r := chi.NewRouter()
r.HandleFunc("/success", func(w http.ResponseWriter, r *http.Request) {
fmt.Printf("Serving success page.")
successPage, err := os.ReadFile("pages/success.html")
if err != nil {
fmt.Printf("failed to load success page: %v\n", err)
}
successPage = []byte(strings.ReplaceAll(string(successPage), "{{access_token}}", string(output)))
w.Write(successPage)
}) })
r.HandleFunc("/error", func(w http.ResponseWriter, r *http.Request) {
fmt.Printf("Serving error page.")
errorPage, err := os.ReadFile("pages/success.html")
if err != nil {
fmt.Printf("failed to load success page: %v\n", err)
}
// errorPage = []byte(strings.ReplaceAll(string(errorPage), "{{access_token}}", output))
w.Write(errorPage)
})
s.Handler = r
return s.ListenAndServe() return s.ListenAndServe()
} }

View file

@ -92,3 +92,11 @@ func GetCommit() string {
} }
return string(bytes) return string(bytes)
} }
func Tokenize(s string) map[string]any {
tokens := make(map[string]any)
// find token enclosed in curly brackets
return tokens
}