Major 'internal' package refactor

This commit is contained in:
David Allen 2024-08-07 10:59:10 -06:00
parent 2c841906b2
commit 6d1dae25ec
No known key found for this signature in database
GPG key ID: 717C593FF60A2ACC
9 changed files with 359 additions and 191 deletions

10
internal/cache/cache.go vendored Normal file
View file

@ -0,0 +1,10 @@
package cache
import "database/sql/driver"
type Cache[T any] interface {
CreateIfNotExists(path string) (driver.Connector, error)
Insert(path string, data ...T) error
Delete(path string, data ...T) error
Get(path string) ([]T, error)
}

97
internal/cache/sqlite/sqlite.go vendored Normal file
View file

@ -0,0 +1,97 @@
package sqlite
import (
"fmt"
magellan "github.com/OpenCHAMI/magellan/internal"
"github.com/jmoiron/sqlx"
)
const TABLE_NAME = "magellan_scanned_assets"
func CreateScannedAssetIfNotExists(path string) (*sqlx.DB, error) {
schema := fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
host TEXT NOT NULL,
port INTEGER NOT NULL,
protocol TEXT,
state INTEGER,
timestamp TIMESTAMP,
PRIMARY KEY (host, port)
);
`, TABLE_NAME)
// TODO: it may help with debugging to check for file permissions here first
db, err := sqlx.Open("sqlite3", path)
if err != nil {
return nil, fmt.Errorf("failed to open database: %v", err)
}
db.MustExec(schema)
return db, nil
}
func InsertScannedAssets(path string, assets ...magellan.ScannedAsset) error {
if assets == nil {
return fmt.Errorf("states == nil")
}
// create database if it doesn't already exist
db, err := CreateScannedAssetIfNotExists(path)
if err != nil {
return err
}
// insert all probe states into db
tx := db.MustBegin()
for _, state := range assets {
sql := fmt.Sprintf(`INSERT OR REPLACE INTO %s (host, port, protocol, state, timestamp)
VALUES (:host, :port, :protocol, :state, :timestamp);`, TABLE_NAME)
_, err := tx.NamedExec(sql, &state)
if err != nil {
fmt.Printf("failed to execute transaction: %v\n", err)
}
}
err = tx.Commit()
if err != nil {
return fmt.Errorf("failed to commit transaction: %v", err)
}
return nil
}
func DeleteScannedAssets(path string, results ...magellan.ScannedAsset) error {
if results == nil {
return fmt.Errorf("no assets found")
}
db, err := sqlx.Open("sqlite3", path)
if err != nil {
return fmt.Errorf("failed to open database: %v", err)
}
tx := db.MustBegin()
for _, state := range results {
sql := fmt.Sprintf(`DELETE FROM %s WHERE host = :host, port = :port;`, TABLE_NAME)
_, err := tx.NamedExec(sql, &state)
if err != nil {
fmt.Printf("failed to execute transaction: %v\n", err)
}
}
err = tx.Commit()
if err != nil {
return fmt.Errorf("failed to commit transaction: %v", err)
}
return nil
}
func GetScannedAssets(path string) ([]magellan.ScannedAsset, error) {
db, err := sqlx.Open("sqlite3", path)
if err != nil {
return nil, fmt.Errorf("failed to open database: %v", err)
}
results := []magellan.ScannedAsset{}
err = db.Select(&results, fmt.Sprintf("SELECT * FROM %s ORDER BY host ASC, port ASC;", TABLE_NAME))
if err != nil {
return nil, fmt.Errorf("failed to retrieve assets: %v", err)
}
return results, nil
}

View file

@ -27,11 +27,9 @@ const (
HTTPS_PORT = 443 HTTPS_PORT = 443
) )
// QueryParams is a collections of common parameters passed to the CLI. // CollectParams is a collection of common parameters passed to the CLI
// Each CLI subcommand has a corresponding implementation function that // for the 'collect' subcommand.
// takes an object as an argument. However, the implementation may not type CollectParams struct {
// use all of the properties within the object.
type QueryParams struct {
Host string // set by the 'host' flag Host string // set by the 'host' flag
Port int // set by the 'port' flag Port int // set by the 'port' flag
Username string // set the BMC username with the 'username' flag Username string // set the BMC username with the 'username' flag
@ -50,7 +48,7 @@ type QueryParams struct {
// //
// Requests can be made to several of the nodes using a goroutine by setting the q.Concurrency // Requests can be made to several of the nodes using a goroutine by setting the q.Concurrency
// property value between 1 and 255. // property value between 1 and 255.
func CollectInventory(scannedResults *[]ScannedResult, params *QueryParams) error { func CollectInventory(scannedResults *[]ScannedAsset, params *CollectParams) error {
// check for available probe states // check for available probe states
if scannedResults == nil { if scannedResults == nil {
return fmt.Errorf("no probe states found") return fmt.Errorf("no probe states found")
@ -65,7 +63,7 @@ func CollectInventory(scannedResults *[]ScannedResult, params *QueryParams) erro
wg sync.WaitGroup wg sync.WaitGroup
found = make([]string, 0, len(*scannedResults)) found = make([]string, 0, len(*scannedResults))
done = make(chan struct{}, params.Concurrency+1) done = make(chan struct{}, params.Concurrency+1)
chanScannedResult = make(chan ScannedResult, params.Concurrency+1) chanScannedResult = make(chan ScannedAsset, params.Concurrency+1)
outputPath = path.Clean(params.OutputPath) outputPath = path.Clean(params.OutputPath)
smdClient = client.NewClient( smdClient = client.NewClient(
client.WithSecureTLS(params.CaCertPath), client.WithSecureTLS(params.CaCertPath),
@ -94,7 +92,7 @@ func CollectInventory(scannedResults *[]ScannedResult, params *QueryParams) erro
// TODO: use pkg/crawler to request inventory data via Redfish // TODO: use pkg/crawler to request inventory data via Redfish
systems, err := crawler.CrawlBMC(crawler.CrawlerConfig{ systems, err := crawler.CrawlBMC(crawler.CrawlerConfig{
URI: fmt.Sprintf("https://%s:%d", sr.Host, sr.Port), URI: fmt.Sprintf("%s:%d", sr.Host, sr.Port),
Username: params.Username, Username: params.Username,
Password: params.Password, Password: params.Password,
Insecure: true, Insecure: true,
@ -131,14 +129,14 @@ func CollectInventory(scannedResults *[]ScannedResult, params *QueryParams) erro
// write JSON data to file if output path is set using hive partitioning strategy // write JSON data to file if output path is set using hive partitioning strategy
if outputPath != "" { if outputPath != "" {
err = os.MkdirAll(outputPath, os.ModeDir) err = os.MkdirAll(outputPath, 0o644)
if err != nil { if err != nil {
log.Error().Err(err).Msg("failed to make directory for output") log.Error().Err(err).Msg("failed to make directory for output")
} else { } else {
// make the output directory to store files // make the output directory to store files
outputPath, err := util.MakeOutputDirectory(outputPath, false) outputPath, err := util.MakeOutputDirectory(outputPath, false)
if err != nil { if err != nil {
log.Error().Msgf("failed to make output directory: %v", err) log.Error().Err(err).Msg("failed to make output directory")
} else { } else {
// write the output to the final path // write the output to the final path
err = os.WriteFile(path.Clean(fmt.Sprintf("%s/%s/%d.json", params.Host, outputPath, time.Now().Unix())), body, os.ModePerm) err = os.WriteFile(path.Clean(fmt.Sprintf("%s/%s/%d.json", params.Host, outputPath, time.Now().Unix())), body, os.ModePerm)
@ -197,6 +195,6 @@ func CollectInventory(scannedResults *[]ScannedResult, params *QueryParams) erro
return nil return nil
} }
func baseRedfishUrl(q *QueryParams) string { func baseRedfishUrl(q *CollectParams) string {
return fmt.Sprintf("%s:%d", q.Host, q.Port) return fmt.Sprintf("%s:%d", q.Host, q.Port)
} }

View file

@ -1,95 +0,0 @@
package sqlite
import (
"fmt"
magellan "github.com/OpenCHAMI/magellan/internal"
"github.com/jmoiron/sqlx"
)
func CreateProbeResultsIfNotExists(path string) (*sqlx.DB, error) {
schema := `
CREATE TABLE IF NOT EXISTS magellan_scanned_ports (
host TEXT NOT NULL,
port INTEGER NOT NULL,
protocol TEXT,
state INTEGER,
timestamp TIMESTAMP,
PRIMARY KEY (host, port)
);
`
// TODO: it may help with debugging to check for file permissions here first
db, err := sqlx.Open("sqlite3", path)
if err != nil {
return nil, fmt.Errorf("failed toopen database: %v", err)
}
db.MustExec(schema)
return db, nil
}
func InsertProbeResults(path string, states *[]magellan.ScannedResult) error {
if states == nil {
return fmt.Errorf("states == nil")
}
// create database if it doesn't already exist
db, err := CreateProbeResultsIfNotExists(path)
if err != nil {
return err
}
// insert all probe states into db
tx := db.MustBegin()
for _, state := range *states {
sql := `INSERT OR REPLACE INTO magellan_scanned_ports (host, port, protocol, state, timestamp)
VALUES (:host, :port, :protocol, :state, :timestamp);`
_, err := tx.NamedExec(sql, &state)
if err != nil {
fmt.Printf("failed toexecute transaction: %v\n", err)
}
}
err = tx.Commit()
if err != nil {
return fmt.Errorf("failed tocommit transaction: %v", err)
}
return nil
}
func DeleteProbeResults(path string, results *[]magellan.ScannedResult) error {
if results == nil {
return fmt.Errorf("no probe results found")
}
db, err := sqlx.Open("sqlite3", path)
if err != nil {
return fmt.Errorf("failed toopen database: %v", err)
}
tx := db.MustBegin()
for _, state := range *results {
sql := `DELETE FROM magellan_scanned_ports WHERE host = :host, port = :port;`
_, err := tx.NamedExec(sql, &state)
if err != nil {
fmt.Printf("failed toexecute transaction: %v\n", err)
}
}
err = tx.Commit()
if err != nil {
return fmt.Errorf("failed tocommit transaction: %v", err)
}
return nil
}
func GetScannedResults(path string) ([]magellan.ScannedResult, error) {
db, err := sqlx.Open("sqlite3", path)
if err != nil {
return nil, fmt.Errorf("failed toopen database: %v", err)
}
results := []magellan.ScannedResult{}
err = db.Select(&results, "SELECT * FROM magellan_scanned_ports ORDER BY host ASC, port ASC;")
if err != nil {
return nil, fmt.Errorf("failed toretrieve probes: %v", err)
}
return results, nil
}

View file

@ -5,13 +5,16 @@ import (
"math" "math"
"net" "net"
"net/http" "net/http"
"net/url"
"strconv"
"sync" "sync"
"time" "time"
"github.com/OpenCHAMI/magellan/internal/util" "github.com/OpenCHAMI/magellan/internal/util"
"github.com/rs/zerolog/log"
) )
type ScannedResult struct { type ScannedAsset struct {
Host string `json:"host"` Host string `json:"host"`
Port int `json:"port"` Port int `json:"port"`
Protocol string `json:"protocol"` Protocol string `json:"protocol"`
@ -19,9 +22,21 @@ type ScannedResult struct {
Timestamp time.Time `json:"timestamp"` Timestamp time.Time `json:"timestamp"`
} }
// ScanParams is a collection of commom parameters passed to the CLI
type ScanParams struct {
TargetHosts [][]string
Scheme string
Protocol string
Concurrency int
Timeout int
DisableProbing bool
Verbose bool
Debug bool
}
// ScanForAssets() performs a net scan on a network to find available services // ScanForAssets() performs a net scan on a network to find available services
// running. The function expects a list of hosts and ports to make requests. // running. The function expects a list of targets (as [][]string) to make requests.
// Note that each all ports will be used per host. // The 2D list is to permit one goroutine per BMC node when making each request.
// //
// This function runs in a goroutine with the "concurrency" flag setting the // This function runs in a goroutine with the "concurrency" flag setting the
// number of concurrent requests. Only one request is made to each BMC node // number of concurrent requests. Only one request is made to each BMC node
@ -34,54 +49,67 @@ type ScannedResult struct {
// remove the service from being stored in the list of scanned results. // remove the service from being stored in the list of scanned results.
// //
// Returns a list of scanned results to be stored in cache (but isn't doing here). // Returns a list of scanned results to be stored in cache (but isn't doing here).
func ScanForAssets(hosts []string, ports []int, concurrency int, timeout int, disableProbing bool, verbose bool) []ScannedResult { func ScanForAssets(params *ScanParams) []ScannedAsset {
var ( var (
results = make([]ScannedResult, 0, len(hosts)) results = make([]ScannedAsset, 0, len(params.TargetHosts))
done = make(chan struct{}, concurrency+1) done = make(chan struct{}, params.Concurrency+1)
chanHost = make(chan string, concurrency+1) chanHosts = make(chan []string, params.Concurrency+1)
) )
if params.Verbose {
log.Info().Msg("starting scan...")
}
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(concurrency) wg.Add(params.Concurrency)
for i := 0; i < concurrency; i++ { for i := 0; i < params.Concurrency; i++ {
go func() { go func() {
for { for {
host, ok := <-chanHost hosts, ok := <-chanHosts
if !ok { if !ok {
wg.Done() wg.Done()
return return
} }
scannedResults := rawConnect(host, ports, timeout, true) for _, host := range hosts {
if !disableProbing { foundAssets, err := rawConnect(host, params.Protocol, params.Timeout, true)
probeResults := []ScannedResult{} // if we failed to connect, exit from the function
for _, result := range scannedResults { if err != nil {
url := fmt.Sprintf("https://%s:%d/redfish/v1/", result.Host, result.Port) if params.Verbose {
res, _, err := util.MakeRequest(nil, url, "GET", nil, nil) log.Debug().Err(err).Msgf("failed to connect to host (%s)", host)
if err != nil || res == nil {
if verbose {
fmt.Printf("failed to make request: %v\n", err)
}
continue
} else if res.StatusCode != http.StatusOK {
if verbose {
fmt.Printf("request returned code: %v\n", res.StatusCode)
}
continue
} else {
probeResults = append(probeResults, result)
} }
wg.Done()
return
}
if !params.DisableProbing {
assetsToAdd := []ScannedAsset{}
for _, foundAsset := range foundAssets {
url := fmt.Sprintf("%s://%s/redfish/v1/", params.Scheme, foundAsset.Host)
res, _, err := util.MakeRequest(nil, url, http.MethodGet, nil, nil)
if err != nil || res == nil {
if params.Verbose {
log.Printf("failed to make request: %v\n", err)
}
continue
} else if res.StatusCode != http.StatusOK {
if params.Verbose {
log.Printf("request returned code: %v\n", res.StatusCode)
}
continue
} else {
assetsToAdd = append(assetsToAdd, foundAsset)
}
}
results = append(results, assetsToAdd...)
} else {
results = append(results, foundAssets...)
} }
results = append(results, probeResults...)
} else {
results = append(results, scannedResults...)
} }
} }
}() }()
} }
for _, host := range hosts { for _, hosts := range params.TargetHosts {
chanHost <- host chanHosts <- hosts
} }
go func() { go func() {
select { select {
@ -92,13 +120,17 @@ func ScanForAssets(hosts []string, ports []int, concurrency int, timeout int, di
time.Sleep(1000) time.Sleep(1000)
} }
}() }()
close(chanHost) close(chanHosts)
wg.Wait() wg.Wait()
close(done) close(done)
if params.Verbose {
log.Info().Msg("scan complete")
}
return results return results
} }
// GenerateHosts() builds a list of hosts to scan using the "subnet" // GenerateHostsWithSubnet() builds a list of hosts to scan using the "subnet"
// and "subnetMask" arguments passed. The function is capable of // and "subnetMask" arguments passed. The function is capable of
// distinguishing between IP formats: a subnet with just an IP address (172.16.0.0) and // distinguishing between IP formats: a subnet with just an IP address (172.16.0.0) and
// a subnet with IP address and CIDR (172.16.0.0/24). // a subnet with IP address and CIDR (172.16.0.0/24).
@ -106,83 +138,111 @@ func ScanForAssets(hosts []string, ports []int, concurrency int, timeout int, di
// NOTE: If a IP address is provided with CIDR, then the "subnetMask" // NOTE: If a IP address is provided with CIDR, then the "subnetMask"
// parameter will be ignored. If neither is provided, then the default // parameter will be ignored. If neither is provided, then the default
// subnet mask will be used instead. // subnet mask will be used instead.
func GenerateHosts(subnet string, subnetMask *net.IP) []string { func GenerateHostsWithSubnet(subnet string, subnetMask *net.IPMask, additionalPorts []int, defaultScheme string) [][]string {
if subnet == "" || subnetMask == nil { if subnet == "" || subnetMask == nil {
return nil return nil
} }
// convert subnets from string to net.IP // convert subnets from string to net.IP to test if CIDR is included
subnetIp := net.ParseIP(subnet) subnetIp := net.ParseIP(subnet)
if subnetIp == nil { if subnetIp == nil {
// try parse CIDR instead // not a valid IP so try again with CIDR
ip, network, err := net.ParseCIDR(subnet) ip, network, err := net.ParseCIDR(subnet)
if err != nil { if err != nil {
return nil return nil
} }
subnetIp = ip subnetIp = ip
if network != nil { if network == nil {
t := net.IP(network.Mask) // use the default subnet mask if a valid one is not provided
subnetMask = &t network = &net.IPNet{
IP: subnetIp,
Mask: net.IPv4Mask(255, 255, 255, 0),
}
} }
subnetMask = &network.Mask
} }
mask := net.IPMask(subnetMask.To4()) // generate new IPs from subnet and format to full URL
subnetIps := generateIPsWithSubnet(&subnetIp, subnetMask)
// if no subnet mask, use a default 24-bit mask (for now) return util.FormatIPUrls(subnetIps, additionalPorts, defaultScheme, false)
return generateHosts(&subnetIp, &mask)
} }
// GetDefaultPorts() returns a list of default ports. The only reason to have
// this function is to add/remove ports without affecting usage.
func GetDefaultPorts() []int { func GetDefaultPorts() []int {
return []int{HTTPS_PORT} return []int{HTTPS_PORT}
} }
func rawConnect(host string, ports []int, timeout int, keepOpenOnly bool) []ScannedResult { // rawConnect() tries to connect to the host using DialTimeout() and waits
results := []ScannedResult{} // until a response is receive or if the timeout (in seconds) expires. This
for _, p := range ports { // function expects a full URL such as https://my.bmc.host:443/ to make the
result := ScannedResult{ // connection.
Host: host, func rawConnect(address string, protocol string, timeoutSeconds int, keepOpenOnly bool) ([]ScannedAsset, error) {
Port: p, uri, err := url.ParseRequestURI(address)
Protocol: "tcp", if err != nil {
return nil, fmt.Errorf("failed to split host/port: %w", err)
}
// convert port to its "proper" type
port, err := strconv.Atoi(uri.Port())
if err != nil {
return nil, fmt.Errorf("failed to convert port to integer type: %w", err)
}
var (
timeoutDuration = time.Second * time.Duration(timeoutSeconds)
assets []ScannedAsset
asset = ScannedAsset{
Host: uri.Host,
Port: port,
Protocol: protocol,
State: false, State: false,
Timestamp: time.Now(), Timestamp: time.Now(),
} }
t := time.Second * time.Duration(timeout) )
port := fmt.Sprint(p)
conn, err := net.DialTimeout("tcp", net.JoinHostPort(host, port), t) // try to conntect to host (expects host in format [10.0.0.0]:443)
if err != nil { target := fmt.Sprintf("[%s]:%s", uri.Hostname(), uri.Port())
result.State = false conn, err := net.DialTimeout(protocol, target, timeoutDuration)
// fmt.Println("Connecting error:", err) if err != nil {
} asset.State = false
if conn != nil { return nil, fmt.Errorf("failed to dial host: %w", err)
result.State = true }
defer conn.Close() if conn != nil {
// fmt.Println("Opened", net.JoinHostPort(host, port)) asset.State = true
} defer conn.Close()
if keepOpenOnly { }
if result.State { if keepOpenOnly {
results = append(results, result) if asset.State {
} assets = append(assets, asset)
} else {
results = append(results, result)
} }
} else {
assets = append(assets, asset)
} }
return results return assets, nil
} }
func generateHosts(ip *net.IP, mask *net.IPMask) []string { // generateIPsWithSubnet() returns a collection of host IP strings with a
// provided subnet mask.
//
// TODO: add a way for filtering/exclude specific IPs and IP ranges.
func generateIPsWithSubnet(ip *net.IP, mask *net.IPMask) []string {
// check if subnet IP and mask are valid
if ip == nil || mask == nil {
log.Error().Msg("invalid subnet IP or mask (ip == nil or mask == nil)")
return nil
}
// get all IP addresses in network // get all IP addresses in network
ones, _ := mask.Size() ones, bits := mask.Size()
hosts := []string{} hosts := []string{}
end := int(math.Pow(2, float64((32-ones)))) - 1 end := int(math.Pow(2, float64((bits-ones)))) - 1
for i := 0; i < end; i++ { for i := 0; i < end; i++ {
// ip[3] = byte(i)
ip = util.GetNextIP(ip, 1) ip = util.GetNextIP(ip, 1)
if ip == nil { if ip == nil {
continue continue
} }
// host := fmt.Sprintf("%v.%v.%v.%v", (*ip)[0], (*ip)[1], (*ip)[2], (*ip)[3])
// fmt.Printf("host: %v\n", ip.String())
hosts = append(hosts, ip.String()) hosts = append(hosts, ip.String())
} }
return hosts return hosts

View file

@ -9,7 +9,7 @@ import (
) )
type UpdateParams struct { type UpdateParams struct {
QueryParams CollectParams
FirmwarePath string FirmwarePath string
FirmwareVersion string FirmwareVersion string
Component string Component string
@ -20,7 +20,7 @@ type UpdateParams struct {
// The function expects the firmware URL, firmware version, and component flags to be // The function expects the firmware URL, firmware version, and component flags to be
// set from the CLI to perform a firmware update. // set from the CLI to perform a firmware update.
func UpdateFirmwareRemote(q *UpdateParams) error { func UpdateFirmwareRemote(q *UpdateParams) error {
url := baseRedfishUrl(&q.QueryParams) + "/redfish/v1/UpdateService/Actions/SimpleUpdate" url := baseRedfishUrl(&q.CollectParams) + "/redfish/v1/UpdateService/Actions/SimpleUpdate"
headers := map[string]string{ headers := map[string]string{
"Content-Type": "application/json", "Content-Type": "application/json",
"cache-control": "no-cache", "cache-control": "no-cache",
@ -47,7 +47,7 @@ func UpdateFirmwareRemote(q *UpdateParams) error {
} }
func GetUpdateStatus(q *UpdateParams) error { func GetUpdateStatus(q *UpdateParams) error {
url := baseRedfishUrl(&q.QueryParams) + "/redfish/v1/UpdateService" url := baseRedfishUrl(&q.CollectParams) + "/redfish/v1/UpdateService"
res, body, err := util.MakeRequest(nil, url, "GET", nil, nil) res, body, err := util.MakeRequest(nil, url, "GET", nil, nil)
if err != nil { if err != nil {
return fmt.Errorf("something went wrong: %v", err) return fmt.Errorf("something went wrong: %v", err)

View file

@ -7,6 +7,10 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"net/url"
"strings"
"github.com/rs/zerolog/log"
) )
// HTTP aliases for readibility // HTTP aliases for readibility
@ -78,3 +82,101 @@ func MakeRequest(client *http.Client, url string, httpMethod string, body HTTPBo
} }
return res, b, err return res, b, err
} }
// FormatHostUrls() takes a list of hosts and ports and builds full URLs in the
// form of scheme://host:port. If no scheme is provided, it will use "https" by
// default.
//
// Returns a 2D string slice where each slice contains URL host strings for each
// port. The intention is to have all of the URLs for a single host combined into
// a single slice to initiate one goroutine per host, but making request to multiple
// ports.
func FormatHostUrls(hosts []string, ports []int, scheme string, verbose bool) [][]string {
// format each positional arg as a complete URL
var formattedHosts [][]string
for _, host := range hosts {
uri, err := url.ParseRequestURI(host)
if err != nil {
if verbose {
log.Warn().Msgf("invalid URI parsed: %s", host)
}
continue
}
// check if scheme is set, if not set it with flag or default value ('https' if flag is not set)
if uri.Scheme == "" {
if scheme != "" {
uri.Scheme = scheme
} else {
// hardcoded assumption
uri.Scheme = "https"
}
}
// tidy up slashes and update arg with new value
uri.Path = strings.TrimSuffix(uri.Path, "/")
uri.Path = strings.ReplaceAll(uri.Path, "//", "/")
// for hosts with unspecified ports, add ports to scan from flag
if uri.Port() == "" {
var tmp []string
for _, port := range ports {
uri.Host += fmt.Sprintf(":%d", port)
tmp = append(tmp, uri.String())
}
formattedHosts = append(formattedHosts, tmp)
} else {
formattedHosts = append(formattedHosts, []string{uri.String()})
}
}
return formattedHosts
}
// FormatIPUrls() takes a list of IP addresses and ports and builds full URLs in the
// form of scheme://host:port. If no scheme is provided, it will use "https" by
// default.
//
// Returns a 2D string slice where each slice contains URL host strings for each
// port. The intention is to have all of the URLs for a single host combined into
// a single slice to initiate one goroutine per host, but making request to multiple
// ports.
func FormatIPUrls(ips []string, ports []int, scheme string, verbose bool) [][]string {
// format each positional arg as a complete URL
var formattedHosts [][]string
for _, ip := range ips {
// if parsing completely fails, try to build new URL object
uri := &url.URL{
Scheme: scheme,
Host: ip,
}
// check if scheme is set, if not set it with flag or default value ('https' if flag is not set)
if uri.Scheme == "" {
if scheme != "" {
uri.Scheme = scheme
} else {
// hardcoded assumption
uri.Scheme = "https"
}
}
// tidy up slashes and update arg with new value
uri.Path = strings.TrimSuffix(uri.Path, "/")
uri.Path = strings.ReplaceAll(uri.Path, "//", "/")
// for hosts with unspecified ports, add ports to scan from flag
if uri.Port() == "" {
var tmp []string
for _, port := range ports {
uri.Host = fmt.Sprintf("%s:%d", ip, port)
tmp = append(tmp, uri.String())
}
formattedHosts = append(formattedHosts, tmp)
} else {
formattedHosts = append(formattedHosts, []string{uri.String()})
}
}
return formattedHosts
}

View file

@ -44,10 +44,6 @@ func SplitPathForViper(path string) (string, string, string) {
// //
// Returns the final path that was created if no errors occurred. Otherwise, // Returns the final path that was created if no errors occurred. Otherwise,
// it returns an empty string with an error. // it returns an empty string with an error.
//
// TODO: Refactor this function for hive partitioning or possibly move into
// the logging package.
// TODO: Add an option to force overwriting the path.
func MakeOutputDirectory(path string, overwrite bool) (string, error) { func MakeOutputDirectory(path string, overwrite bool) (string, error) {
// get the current data + time using Go's stupid formatting // get the current data + time using Go's stupid formatting
t := time.Now() t := time.Now()