mirror of
https://github.com/davidallendj/opaal.git
synced 2025-12-20 03:27:02 -07:00
Refactored, reorganized, fixed issues, and implemented functionality
This commit is contained in:
parent
3edc5e1191
commit
5428085fdf
9 changed files with 524 additions and 190 deletions
|
|
@ -1,99 +1,13 @@
|
||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"davidallendj/opaal/internal/oauth"
|
opaal "davidallendj/opaal/internal"
|
||||||
"davidallendj/opaal/internal/oidc"
|
|
||||||
"davidallendj/opaal/internal/util"
|
"davidallendj/opaal/internal/util"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
yaml "gopkg.in/yaml.v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Server struct {
|
|
||||||
Host string `yaml:"host"`
|
|
||||||
Port int `yaml:"port"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AuthEndpoints struct {
|
|
||||||
Identities string `yaml:"identities"`
|
|
||||||
TrustedIssuers string `yaml:"trusted-issuers"`
|
|
||||||
AccessToken string `yaml:"access-token"`
|
|
||||||
ServerConfig string `yaml:"server-config"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
Server Server `yaml:"server"`
|
|
||||||
Client oauth.Client `yaml:"client"`
|
|
||||||
IdentityProvider oidc.IdentityProvider `yaml:"oidc"`
|
|
||||||
State string `yaml:"state"`
|
|
||||||
ResponseType string `yaml:"response-type"`
|
|
||||||
Scope []string `yaml:"scope"`
|
|
||||||
AuthEndpoints AuthEndpoints `yaml:"urls"`
|
|
||||||
OpenBrowser bool `yaml:"open-browser"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewConfig() Config {
|
|
||||||
return Config{
|
|
||||||
Server: Server{
|
|
||||||
Host: "127.0.0.1",
|
|
||||||
Port: 3333,
|
|
||||||
},
|
|
||||||
Client: oauth.Client{
|
|
||||||
Id: "",
|
|
||||||
Secret: "",
|
|
||||||
RedirectUris: []string{""},
|
|
||||||
},
|
|
||||||
IdentityProvider: *oidc.NewIdentityProvider(),
|
|
||||||
State: util.RandomString(20),
|
|
||||||
ResponseType: "code",
|
|
||||||
Scope: []string{"openid", "profile", "email"},
|
|
||||||
AuthEndpoints: AuthEndpoints{
|
|
||||||
Identities: "",
|
|
||||||
AccessToken: "",
|
|
||||||
TrustedIssuers: "",
|
|
||||||
ServerConfig: "",
|
|
||||||
},
|
|
||||||
OpenBrowser: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadConfig(path string) Config {
|
|
||||||
var c Config = NewConfig()
|
|
||||||
file, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("failed to read config file: %v\n", err)
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
err = yaml.Unmarshal(file, &c)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("failed to unmarshal config: %v\n", err)
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
func SaveDefaultConfig(path string) {
|
|
||||||
path = filepath.Clean(path)
|
|
||||||
if path == "" || path == "." {
|
|
||||||
path = "config.yaml"
|
|
||||||
}
|
|
||||||
var c = NewConfig()
|
|
||||||
data, err := yaml.Marshal(c)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("failed to marshal config: %v\n", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = os.WriteFile(path, data, os.ModePerm)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("failed to write default config file: %v\n", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var configCmd = &cobra.Command{
|
var configCmd = &cobra.Command{
|
||||||
Use: "config",
|
Use: "config",
|
||||||
Short: "Create a new default config file",
|
Short: "Create a new default config file",
|
||||||
|
|
@ -105,7 +19,7 @@ var configCmd = &cobra.Command{
|
||||||
fmt.Printf("file or directory exists\n")
|
fmt.Printf("file or directory exists\n")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
SaveDefaultConfig(path)
|
opaal.SaveDefaultConfig(path)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
||||||
104
cmd/login.go
104
cmd/login.go
|
|
@ -1,22 +1,14 @@
|
||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"davidallendj/opal/internal/api"
|
opaal "davidallendj/opaal/internal"
|
||||||
"davidallendj/opal/internal/oidc"
|
"davidallendj/opaal/internal/util"
|
||||||
"davidallendj/opal/internal/util"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
func hasRequiredParams(config *Config) bool {
|
|
||||||
return config.Client.Id != "" && config.Client.Secret != ""
|
|
||||||
}
|
|
||||||
|
|
||||||
var loginCmd = &cobra.Command{
|
var loginCmd = &cobra.Command{
|
||||||
Use: "login",
|
Use: "login",
|
||||||
Short: "Start the login flow",
|
Short: "Start the login flow",
|
||||||
|
|
@ -28,95 +20,15 @@ var loginCmd = &cobra.Command{
|
||||||
fmt.Printf("failed to load config")
|
fmt.Printf("failed to load config")
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
} else if exists {
|
} else if exists {
|
||||||
config = LoadConfig(configPath)
|
config = opaal.LoadConfig(configPath)
|
||||||
} else {
|
} else {
|
||||||
config = NewConfig()
|
config = opaal.NewConfig()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// try and fetch server configuration if provided URL
|
err := opaal.Login(&config)
|
||||||
idp := oidc.NewIdentityProvider()
|
|
||||||
if config.AuthEndpoints.ServerConfig != "" {
|
|
||||||
idp.FetchServerConfig(config.AuthEndpoints.ServerConfig)
|
|
||||||
} else {
|
|
||||||
// otherwise, use what's provided in config file
|
|
||||||
idp.Issuer = config.IdentityProvider.Issuer
|
|
||||||
idp.Endpoints = config.IdentityProvider.Endpoints
|
|
||||||
idp.Supported = config.IdentityProvider.Supported
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if all appropriate parameters are set in config
|
|
||||||
if !hasRequiredParams(&config) {
|
|
||||||
fmt.Printf("client ID must be set\n")
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// build the authorization URL to redirect user for social sign-in
|
|
||||||
var authorizationUrl = util.BuildAuthorizationUrl(
|
|
||||||
idp.Endpoints.Authorize,
|
|
||||||
config.Client.Id,
|
|
||||||
config.Client.RedirectUris,
|
|
||||||
config.State,
|
|
||||||
config.ResponseType,
|
|
||||||
config.Scope,
|
|
||||||
)
|
|
||||||
|
|
||||||
// print the authorization URL for sharing
|
|
||||||
serverAddr := fmt.Sprintf("%s:%d", config.IdentityProvider.Issuer)
|
|
||||||
fmt.Printf(`Login with identity provider:
|
|
||||||
%s/login
|
|
||||||
%s\n`,
|
|
||||||
serverAddr, authorizationUrl,
|
|
||||||
)
|
|
||||||
|
|
||||||
// automatically open browser to initiate login flow (only useful for testing)
|
|
||||||
if config.OpenBrowser {
|
|
||||||
util.OpenUrl(authorizationUrl)
|
|
||||||
}
|
|
||||||
|
|
||||||
// authorize oauth client and listen for callback from provider
|
|
||||||
fmt.Printf("Waiting for authorization code redirect @%s/oidc/callback...\n", serverAddr)
|
|
||||||
code, err := api.WaitForAuthorizationCode(serverAddr, authorizationUrl)
|
|
||||||
if errors.Is(err, http.ErrServerClosed) {
|
|
||||||
fmt.Printf("Server closed.\n")
|
|
||||||
} else if err != nil {
|
|
||||||
fmt.Printf("Error starting server: %s\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// use code from response and exchange for bearer token (with ID token)
|
|
||||||
tokenString, err := api.FetchIssuerToken(
|
|
||||||
code,
|
|
||||||
idp.Endpoints.Token,
|
|
||||||
config.Client,
|
|
||||||
config.State,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("%v\n", err)
|
fmt.Print(err)
|
||||||
return
|
os.Exit(1)
|
||||||
}
|
|
||||||
|
|
||||||
// extract ID token from bearer as JSON string for easy consumption
|
|
||||||
var data map[string]any
|
|
||||||
json.Unmarshal([]byte(tokenString), &data)
|
|
||||||
idToken := data["id_token"].(string)
|
|
||||||
|
|
||||||
// create a new identity with identity and session manager if url is provided
|
|
||||||
if config.AuthEndpoints.Identities != "" {
|
|
||||||
api.CreateIdentity(config.AuthEndpoints.Identities, idToken)
|
|
||||||
api.FetchIdentities(config.AuthEndpoints.Identities)
|
|
||||||
}
|
|
||||||
|
|
||||||
// fetch JWKS and add issuer to authentication server to submit ID token
|
|
||||||
err = idp.FetchJwk("")
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("failed to fetch JWK: %v\n", err)
|
|
||||||
} else {
|
|
||||||
api.AddTrustedIssuer(config.AuthEndpoints.TrustedIssuers, idp.Key)
|
|
||||||
}
|
|
||||||
|
|
||||||
// use ID token/user info to fetch access token from authentication server
|
|
||||||
if config.AuthEndpoints.AccessToken != "" {
|
|
||||||
api.FetchAccessToken(config.AuthEndpoints.AccessToken, config.Client.Id, idToken, config.Scope)
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -131,5 +43,7 @@ func init() {
|
||||||
loginCmd.Flags().StringVar(&config.Server.Host, "host", config.Server.Host, "set the listening host")
|
loginCmd.Flags().StringVar(&config.Server.Host, "host", config.Server.Host, "set the listening host")
|
||||||
loginCmd.Flags().IntVar(&config.Server.Port, "port", config.Server.Port, "set the listening port")
|
loginCmd.Flags().IntVar(&config.Server.Port, "port", config.Server.Port, "set the listening port")
|
||||||
loginCmd.Flags().BoolVar(&config.OpenBrowser, "open-browser", config.OpenBrowser, "automatically open link in browser")
|
loginCmd.Flags().BoolVar(&config.OpenBrowser, "open-browser", config.OpenBrowser, "automatically open link in browser")
|
||||||
|
loginCmd.Flags().BoolVar(&config.DecodeIdToken, "decode-id-token", config.DecodeIdToken, "decode and print ID token from identity provider")
|
||||||
|
loginCmd.Flags().BoolVar(&config.DecodeAccessToken, "decore-access-token", config.DecodeAccessToken, "decode and print access token from authorization server")
|
||||||
rootCmd.AddCommand(loginCmd)
|
rootCmd.AddCommand(loginCmd)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
opaal "davidallendj/opaal/internal"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
|
@ -9,7 +10,7 @@ import (
|
||||||
|
|
||||||
var (
|
var (
|
||||||
configPath = ""
|
configPath = ""
|
||||||
config Config
|
config opaal.Config
|
||||||
)
|
)
|
||||||
var rootCmd = &cobra.Command{
|
var rootCmd = &cobra.Command{
|
||||||
Use: "oidc",
|
Use: "oidc",
|
||||||
|
|
|
||||||
87
internal/config.go
Normal file
87
internal/config.go
Normal file
|
|
@ -0,0 +1,87 @@
|
||||||
|
package opaal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"davidallendj/opaal/internal/oauth"
|
||||||
|
"davidallendj/opaal/internal/oidc"
|
||||||
|
"davidallendj/opaal/internal/util"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Version string `yaml:"version"`
|
||||||
|
Server Server `yaml:"server"`
|
||||||
|
Client oauth.Client `yaml:"client"`
|
||||||
|
IdentityProvider oidc.IdentityProvider `yaml:"oidc"`
|
||||||
|
State string `yaml:"state"`
|
||||||
|
ResponseType string `yaml:"response-type"`
|
||||||
|
Scope []string `yaml:"scope"`
|
||||||
|
ActionUrls ActionUrls `yaml:"urls"`
|
||||||
|
OpenBrowser bool `yaml:"open-browser"`
|
||||||
|
DecodeIdToken bool `yaml:"decode-id-token"`
|
||||||
|
DecodeAccessToken bool `yaml:"decode-access-token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConfig() Config {
|
||||||
|
return Config{
|
||||||
|
Version: util.GetCommit(),
|
||||||
|
Server: Server{
|
||||||
|
Host: "127.0.0.1",
|
||||||
|
Port: 3333,
|
||||||
|
},
|
||||||
|
Client: oauth.Client{
|
||||||
|
Id: "",
|
||||||
|
Secret: "",
|
||||||
|
RedirectUris: []string{""},
|
||||||
|
},
|
||||||
|
IdentityProvider: *oidc.NewIdentityProvider(),
|
||||||
|
State: util.RandomString(20),
|
||||||
|
ResponseType: "code",
|
||||||
|
Scope: []string{"openid", "profile", "email"},
|
||||||
|
ActionUrls: ActionUrls{
|
||||||
|
Identities: "",
|
||||||
|
AccessToken: "",
|
||||||
|
TrustedIssuers: "",
|
||||||
|
ServerConfig: "",
|
||||||
|
},
|
||||||
|
OpenBrowser: false,
|
||||||
|
DecodeIdToken: false,
|
||||||
|
DecodeAccessToken: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadConfig(path string) Config {
|
||||||
|
var c Config = NewConfig()
|
||||||
|
file, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to read config file: %v\n", err)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
err = yaml.Unmarshal(file, &c)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to unmarshal config: %v\n", err)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func SaveDefaultConfig(path string) {
|
||||||
|
path = filepath.Clean(path)
|
||||||
|
if path == "" || path == "." {
|
||||||
|
path = "config.yaml"
|
||||||
|
}
|
||||||
|
var c = NewConfig()
|
||||||
|
data, err := yaml.Marshal(c)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to marshal config: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = os.WriteFile(path, data, os.ModePerm)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to write default config file: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -3,7 +3,7 @@ package oauth
|
||||||
type Client struct {
|
type Client struct {
|
||||||
Id string `yaml:"id"`
|
Id string `yaml:"id"`
|
||||||
Secret string `yaml:"secret"`
|
Secret string `yaml:"secret"`
|
||||||
RedirectUris []string `yaml:"redirect_uris"`
|
RedirectUris []string `yaml:"redirect-uris"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient() *Client {
|
func NewClient() *Client {
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,13 @@
|
||||||
package oidc
|
package oidc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
|
||||||
"github.com/lestrrat-go/jwx/jwk"
|
"github.com/lestrrat-go/jwx/jwk"
|
||||||
)
|
)
|
||||||
|
|
@ -15,7 +20,7 @@ type IdentityProvider struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Endpoints struct {
|
type Endpoints struct {
|
||||||
Authorize string `json:"authorize_endpoint" yaml:"authorize"`
|
Authorization string `json:"authorization_endpoint" yaml:"authorization"`
|
||||||
Token string `json:"token_endpoint" yaml:"token"`
|
Token string `json:"token_endpoint" yaml:"token"`
|
||||||
Revocation string `json:"revocation_endpoint" yaml:"revocation"`
|
Revocation string `json:"revocation_endpoint" yaml:"revocation"`
|
||||||
Introspection string `json:"introspection_endpoint" yaml:"introspection"`
|
Introspection string `json:"introspection_endpoint" yaml:"introspection"`
|
||||||
|
|
@ -36,7 +41,7 @@ type Supported struct {
|
||||||
func NewIdentityProvider() *IdentityProvider {
|
func NewIdentityProvider() *IdentityProvider {
|
||||||
p := &IdentityProvider{Issuer: "127.0.0.1"}
|
p := &IdentityProvider{Issuer: "127.0.0.1"}
|
||||||
p.Endpoints = Endpoints{
|
p.Endpoints = Endpoints{
|
||||||
Authorize: p.Issuer + "/oauth/authorize",
|
Authorization: p.Issuer + "/oauth/authorize",
|
||||||
Token: p.Issuer + "/oauth/token",
|
Token: p.Issuer + "/oauth/token",
|
||||||
Revocation: p.Issuer + "/oauth/revocation",
|
Revocation: p.Issuer + "/oauth/revocation",
|
||||||
Introspection: p.Issuer + "/oauth/introspect",
|
Introspection: p.Issuer + "/oauth/introspect",
|
||||||
|
|
@ -69,12 +74,70 @@ func NewIdentityProvider() *IdentityProvider {
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *IdentityProvider) FetchServerConfig(url string) {
|
func (p *IdentityProvider) ParseServerConfig(data []byte) error {
|
||||||
|
// parse JSON into IdentityProvider fields
|
||||||
|
var ep Endpoints
|
||||||
|
var s Supported
|
||||||
|
var e error
|
||||||
|
epErr := json.Unmarshal(data, &ep)
|
||||||
|
if epErr != nil {
|
||||||
|
e = fmt.Errorf("%v", epErr)
|
||||||
|
}
|
||||||
|
sErr := json.Unmarshal(data, &s)
|
||||||
|
if sErr != nil {
|
||||||
|
e = fmt.Errorf("%v %v", e, sErr)
|
||||||
|
}
|
||||||
|
err := json.Unmarshal(data, p)
|
||||||
|
if err != nil {
|
||||||
|
e = fmt.Errorf("%v %v", e, err)
|
||||||
|
}
|
||||||
|
p.Endpoints = ep
|
||||||
|
p.Supported = s
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *IdentityProvider) LoadServerConfig(path string) error {
|
||||||
|
// load server config from local file i
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read server config: %v", err)
|
||||||
|
}
|
||||||
|
err = p.ParseServerConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse server config: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *IdentityProvider) FetchServerConfig(url string) error {
|
||||||
// make a request to a server's openid-configuration
|
// make a request to a server's openid-configuration
|
||||||
|
req, err := http.NewRequest("GET", url, bytes.NewBuffer([]byte{}))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create a new request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to do request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read response body: %v", err)
|
||||||
|
}
|
||||||
|
err = p.ParseServerConfig(body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse server config: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *IdentityProvider) FetchJwk(url string) error {
|
func (p *IdentityProvider) FetchJwk(url string) error {
|
||||||
//
|
if url == "" {
|
||||||
|
url = p.Endpoints.Jwks
|
||||||
|
}
|
||||||
|
// fetch JWKS from identity provider
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
set, err := jwk.Fetch(ctx, url)
|
set, err := jwk.Fetch(ctx, url)
|
||||||
|
|
|
||||||
334
internal/opaal.go
Normal file
334
internal/opaal.go
Normal file
|
|
@ -0,0 +1,334 @@
|
||||||
|
package opaal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"davidallendj/opaal/internal/oauth"
|
||||||
|
"davidallendj/opaal/internal/oidc"
|
||||||
|
"davidallendj/opaal/internal/util"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
Host string `yaml:"host"`
|
||||||
|
Port int `yaml:"port"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ActionUrls struct {
|
||||||
|
Identities string `yaml:"identities"`
|
||||||
|
TrustedIssuers string `yaml:"trusted-issuers"`
|
||||||
|
AccessToken string `yaml:"access-token"`
|
||||||
|
ServerConfig string `yaml:"server-config"`
|
||||||
|
JwksUri string `yaml:"jwks_uri"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func Login(config *Config) error {
|
||||||
|
if config == nil {
|
||||||
|
return fmt.Errorf("config is not valid")
|
||||||
|
}
|
||||||
|
// try and fetch server configuration if provided URL
|
||||||
|
idp := oidc.NewIdentityProvider()
|
||||||
|
if config.ActionUrls.ServerConfig != "" {
|
||||||
|
fmt.Printf("Fetching server configuration: %s\n", config.ActionUrls.ServerConfig)
|
||||||
|
err := idp.FetchServerConfig(config.ActionUrls.ServerConfig)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to fetch server config: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// otherwise, use what's provided in config file
|
||||||
|
idp.Issuer = config.IdentityProvider.Issuer
|
||||||
|
idp.Endpoints = config.IdentityProvider.Endpoints
|
||||||
|
idp.Supported = config.IdentityProvider.Supported
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if all appropriate parameters are set in config
|
||||||
|
if !hasRequiredParams(config) {
|
||||||
|
return fmt.Errorf("client ID must be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// build the authorization URL to redirect user for social sign-in
|
||||||
|
var authorizationUrl = util.BuildAuthorizationUrl(
|
||||||
|
idp.Endpoints.Authorization,
|
||||||
|
config.Client.Id,
|
||||||
|
config.Client.RedirectUris,
|
||||||
|
config.State,
|
||||||
|
config.ResponseType,
|
||||||
|
config.Scope,
|
||||||
|
)
|
||||||
|
|
||||||
|
// print the authorization URL for sharing
|
||||||
|
serverAddr := fmt.Sprintf("%s:%d", config.Server.Host, config.Server.Port)
|
||||||
|
fmt.Printf("Login with identity provider:\n\n %s/login\n %s\n\n",
|
||||||
|
serverAddr, authorizationUrl,
|
||||||
|
)
|
||||||
|
|
||||||
|
// automatically open browser to initiate login flow (only useful for testing)
|
||||||
|
if config.OpenBrowser {
|
||||||
|
util.OpenUrl(authorizationUrl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// authorize oauth client and listen for callback from provider
|
||||||
|
fmt.Printf("Waiting for authorization code redirect @%s/oidc/callback...\n", serverAddr)
|
||||||
|
code, err := WaitForAuthorizationCode(serverAddr, authorizationUrl)
|
||||||
|
if errors.Is(err, http.ErrServerClosed) {
|
||||||
|
fmt.Printf("Server closed.\n")
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("failed to start server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// use code from response and exchange for bearer token (with ID token)
|
||||||
|
tokenString, err := FetchIssuerToken(
|
||||||
|
code,
|
||||||
|
idp.Endpoints.Token,
|
||||||
|
config.Client,
|
||||||
|
config.State,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to fetch token from issuer: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// unmarshal data to get id_token and access_token
|
||||||
|
var data map[string]any
|
||||||
|
err = json.Unmarshal([]byte(tokenString), &data)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// extract ID token from bearer as JSON string for easy consumption
|
||||||
|
idToken := data["id_token"].(string)
|
||||||
|
idJwtSegments, err := util.DecodeJwt(idToken)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("failed to parse ID token: %v\n", err)
|
||||||
|
} else {
|
||||||
|
if config.DecodeIdToken {
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("failed to decode JWT: %v\n", err)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("id_token.header: %s\nid_token.payload: %s\n", string(idJwtSegments[0]), string(idJwtSegments[1]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extract the access token to get the scopes
|
||||||
|
// accessToken := data["access_token"].(string)
|
||||||
|
// accessJwtSegments, err := util.DecodeJwt(accessToken)
|
||||||
|
// if err != nil || len(accessJwtSegments) <= {
|
||||||
|
// fmt.Printf("failed to parse access token: %v\n", err)
|
||||||
|
// } else {
|
||||||
|
// if config.DecodeIdToken {
|
||||||
|
// if err != nil {
|
||||||
|
// fmt.Printf("failed to decode JWT: %v\n", err)
|
||||||
|
// } else {
|
||||||
|
// fmt.Printf("access_token.header: %s\naccess_token.payload: %s\n", string(accessJwtSegments[0]), string(accessJwtSegments[1]))
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// create a new identity with identity and session manager if url is provided
|
||||||
|
if config.ActionUrls.Identities != "" {
|
||||||
|
CreateIdentity(config.ActionUrls.Identities, idToken)
|
||||||
|
FetchIdentities(config.ActionUrls.Identities)
|
||||||
|
}
|
||||||
|
|
||||||
|
// extract the subject from ID token claims
|
||||||
|
var subject string
|
||||||
|
var idJsonPayload map[string]any
|
||||||
|
var idJwtPayload []byte = idJwtSegments[1]
|
||||||
|
if idJwtPayload != nil {
|
||||||
|
err := json.Unmarshal(idJwtPayload, &idJsonPayload)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal JWT: %v", err)
|
||||||
|
}
|
||||||
|
subject = idJsonPayload["sub"].(string)
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("failed to extract subject from ID token claims")
|
||||||
|
}
|
||||||
|
|
||||||
|
// extract the scope from access token claims
|
||||||
|
// var scope []string
|
||||||
|
// var accessJsonPayload map[string]any
|
||||||
|
// var accessJwtPayload []byte = accessJwtSegments[1]
|
||||||
|
// if accessJsonPayload != nil {
|
||||||
|
// err := json.Unmarshal(accessJwtPayload, &accessJsonPayload)
|
||||||
|
// if err != nil {
|
||||||
|
// return fmt.Errorf("failed to unmarshal JWT: %v", err)
|
||||||
|
// }
|
||||||
|
// scope = idJsonPayload["scope"].([]string)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// fetch JWKS and add issuer to authentication server to submit ID token
|
||||||
|
fmt.Printf("Fetching JWKS for verification...\n")
|
||||||
|
err = idp.FetchJwk(config.ActionUrls.JwksUri)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("failed to fetch JWK: %v\n", err)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Attempting to add issuer to authorization server...\n")
|
||||||
|
err = AddTrustedIssuer(config.ActionUrls.TrustedIssuers, *idp, subject, time.Duration(1000), config.Scope)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add trusted issuer: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// use ID token/user info to fetch access token from authentication server
|
||||||
|
if config.ActionUrls.AccessToken != "" {
|
||||||
|
fmt.Printf("Fetching access token from authorization server...")
|
||||||
|
accessToken, err := FetchAccessToken(config.ActionUrls.AccessToken, config.Client.Id, idToken, config.Scope)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to fetch access token: %v", err)
|
||||||
|
}
|
||||||
|
fmt.Printf("%s\n", accessToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Success!")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func WaitForAuthorizationCode(serverAddr string, loginUrl string) (string, error) {
|
||||||
|
var code string
|
||||||
|
s := &http.Server{Addr: serverAddr}
|
||||||
|
http.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// redirect directly to identity provider with this endpoint
|
||||||
|
http.Redirect(w, r, loginUrl, http.StatusSeeOther)
|
||||||
|
})
|
||||||
|
http.HandleFunc("/oidc/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// get the code from the OIDC provider
|
||||||
|
if r != nil {
|
||||||
|
code = r.URL.Query().Get("code")
|
||||||
|
fmt.Printf("Authorization code: %v\n", code)
|
||||||
|
}
|
||||||
|
s.Close()
|
||||||
|
})
|
||||||
|
return code, s.ListenAndServe()
|
||||||
|
}
|
||||||
|
|
||||||
|
func FetchIssuerToken(code string, remoteUrl string, client oauth.Client, state string) (string, error) {
|
||||||
|
var token string
|
||||||
|
data := url.Values{
|
||||||
|
"grant_type": {"authorization_code"},
|
||||||
|
"code": {code},
|
||||||
|
"client_id": {client.Id},
|
||||||
|
"client_secret": {client.Secret},
|
||||||
|
"state": {state},
|
||||||
|
"redirect_uri": {strings.Join(client.RedirectUris, ",")},
|
||||||
|
}
|
||||||
|
res, err := http.PostForm(remoteUrl, data)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get ID token: %s", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
b, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to read response body: %v", err)
|
||||||
|
}
|
||||||
|
token = string(b)
|
||||||
|
|
||||||
|
fmt.Printf("%v\n", token)
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func FetchAccessToken(remoteUrl string, clientId string, jwt string, scopes []string) (string, error) {
|
||||||
|
// hydra endpoint: /oauth/token
|
||||||
|
var token string
|
||||||
|
data := url.Values{
|
||||||
|
"grant_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
|
||||||
|
"assertion": {jwt},
|
||||||
|
}
|
||||||
|
res, err := http.PostForm(remoteUrl, data)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get token: %s", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
b, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to read response body: %v", err)
|
||||||
|
}
|
||||||
|
token = string(b)
|
||||||
|
|
||||||
|
fmt.Printf("%v\n", token)
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func AddTrustedIssuer(remoteUrl string, idp oidc.IdentityProvider, subject string, duration time.Duration, scope []string) error {
|
||||||
|
// hydra endpoint: /admin/trust/grants/jwt-bearer/issuers
|
||||||
|
jwkstr, err := json.Marshal(idp.Key)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal JWK: %v", err)
|
||||||
|
}
|
||||||
|
data := []byte(fmt.Sprintf(`{
|
||||||
|
"allow_any_subject": false,
|
||||||
|
"issuer": "%s",
|
||||||
|
"subject": "%s"
|
||||||
|
"expires_at": "%v"
|
||||||
|
"jwk": %v,
|
||||||
|
"scope": [ j%s ],
|
||||||
|
}`, idp.Issuer, subject, time.Now().Add(duration), string(jwkstr), strings.Join(scope, ",")))
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", remoteUrl, bytes.NewBuffer(data))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create a new request: %v", err)
|
||||||
|
}
|
||||||
|
req.Header.Add("Content-Type", "application/json")
|
||||||
|
// req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", idToken))
|
||||||
|
client := &http.Client{}
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to do request: %v", err)
|
||||||
|
}
|
||||||
|
fmt.Printf("%d\n", res.StatusCode)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateIdentity(remoteUrl string, idToken string) error {
|
||||||
|
// kratos endpoint: /admin/identities
|
||||||
|
data := []byte(`{
|
||||||
|
"schema_id": "preset://email",
|
||||||
|
"traits": {
|
||||||
|
"email": "docs@example.org"
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", remoteUrl, bytes.NewBuffer(data))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create a new request: %v", err)
|
||||||
|
}
|
||||||
|
req.Header.Add("Content-Type", "application/json")
|
||||||
|
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", idToken))
|
||||||
|
client := &http.Client{}
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to do request: %v", err)
|
||||||
|
}
|
||||||
|
fmt.Printf("%d\n", res.StatusCode)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func FetchIdentities(remoteUrl string) error {
|
||||||
|
req, err := http.NewRequest("GET", remoteUrl, bytes.NewBuffer([]byte{}))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create a new request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to do request: %v", err)
|
||||||
|
}
|
||||||
|
fmt.Printf("%v\n", res)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func RedirectSuccess() {
|
||||||
|
// show a success page with the user's access token
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasRequiredParams(config *Config) bool {
|
||||||
|
return config.Client.Id != "" && config.Client.Secret != ""
|
||||||
|
}
|
||||||
|
|
@ -8,6 +8,8 @@ import (
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
)
|
)
|
||||||
|
|
||||||
func RandomString(n int) string {
|
func RandomString(n int) string {
|
||||||
|
|
@ -50,6 +52,17 @@ func EncodeBase64(s string) string {
|
||||||
return base64.StdEncoding.EncodeToString([]byte(s))
|
return base64.StdEncoding.EncodeToString([]byte(s))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DecodeJwt(encoded string) ([][]byte, error) {
|
||||||
|
// split the string into 3 segments and decode
|
||||||
|
segments := strings.Split(encoded, ".")
|
||||||
|
decoded := [][]byte{}
|
||||||
|
for _, segment := range segments {
|
||||||
|
bytes, _ := jwt.DecodeSegment(segment)
|
||||||
|
decoded = append(decoded, bytes)
|
||||||
|
}
|
||||||
|
return decoded, nil
|
||||||
|
}
|
||||||
|
|
||||||
func PathExists(path string) (bool, error) {
|
func PathExists(path string) (bool, error) {
|
||||||
_, err := os.Stat(path)
|
_, err := os.Stat(path)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|
@ -79,3 +92,11 @@ func OpenUrl(url string) error {
|
||||||
args = append(args, url)
|
args = append(args, url)
|
||||||
return exec.Command(cmd, args...).Start()
|
return exec.Command(cmd, args...).Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetCommit() string {
|
||||||
|
bytes, err := exec.Command("git", "rev --parse HEAD").Output()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(bytes)
|
||||||
|
}
|
||||||
|
|
|
||||||
2
main.go
2
main.go
|
|
@ -1,6 +1,6 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import "davidallendj/oidc-auth/cmd"
|
import "davidallendj/opaal/cmd"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
userDB = ""
|
userDB = ""
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue