WIP: Version checks

This commit is contained in:
Tyler
2023-01-07 10:15:08 -05:00
parent 1bbd2920b5
commit ee2b80a7df
5 changed files with 177 additions and 34 deletions

125
check.go
View File

@@ -6,9 +6,11 @@ import (
"errors"
"fmt"
log "github.com/sirupsen/logrus"
"io"
"net"
"net/http"
"net/url"
"path"
"runtime"
"strings"
"time"
@@ -28,14 +30,18 @@ var (
ErrCertExpired = errors.New("certificate is expired")
)
func (r *Redirector) checkHTTP(scheme string) ServerCheck {
return func(server *Server, logFields log.Fields) (bool, error) {
return r.checkHTTPScheme(server, scheme, logFields)
}
// HTTPCheck is a check for validity and redirects
type HTTPCheck struct {
config *Config
}
// checkHTTPScheme checks a URL for validity, and checks redirects
func (r *Redirector) checkHTTPScheme(server *Server, scheme string, logFields log.Fields) (bool, error) {
// Check checks a URL for validity, and checks redirects
func (h *HTTPCheck) Check(server *Server, logFields log.Fields) (bool, error) {
return h.checkHTTPScheme(server, "http", logFields)
}
// checkHTTPScheme will check if a scheme is valid and doesn't redirect
func (h *HTTPCheck) checkHTTPScheme(server *Server, scheme string, logFields log.Fields) (bool, error) {
u := &url.URL{
Scheme: scheme,
Host: server.Host,
@@ -50,7 +56,7 @@ func (r *Redirector) checkHTTPScheme(server *Server, scheme string, logFields lo
return false, err
}
res, err := r.config.checkClient.Do(req)
res, err := h.config.checkClient.Do(req)
if err != nil {
return false, err
@@ -66,18 +72,18 @@ func (r *Redirector) checkHTTPScheme(server *Server, scheme string, logFields lo
switch u.Scheme {
case "http":
res, err := r.checkRedirect(u.Scheme, location)
res, err := h.checkRedirect(u.Scheme, location)
if !res || err != nil {
// If we don't support http, we remove it from supported protocols
server.Protocols = server.Protocols.Remove("http")
} else {
// Otherwise, we verify https support
r.checkProtocol(server, "https")
h.checkProtocol(server, "https")
}
case "https":
// We don't want to allow downgrading, so this is an error.
return r.checkRedirect(u.Scheme, location)
return h.checkRedirect(u.Scheme, location)
}
}
@@ -89,8 +95,8 @@ func (r *Redirector) checkHTTPScheme(server *Server, scheme string, logFields lo
return false, nil
}
func (r *Redirector) checkProtocol(server *Server, scheme string) {
res, err := r.checkHTTPScheme(server, scheme, log.Fields{})
func (h *HTTPCheck) checkProtocol(server *Server, scheme string) {
res, err := h.checkHTTPScheme(server, scheme, log.Fields{})
if !res || err != nil {
return
@@ -102,7 +108,7 @@ func (r *Redirector) checkProtocol(server *Server, scheme string) {
}
// checkRedirect parses a location header response and checks the scheme
func (r *Redirector) checkRedirect(originatingScheme, locationHeader string) (bool, error) {
func (h *HTTPCheck) checkRedirect(originatingScheme, locationHeader string) (bool, error) {
newURL, err := url.Parse(locationHeader)
if err != nil {
@@ -118,8 +124,13 @@ func (r *Redirector) checkRedirect(originatingScheme, locationHeader string) (bo
return true, nil
}
// checkTLS checks tls certificates from a host, ensures they're valid, and not expired.
func (r *Redirector) checkTLS(server *Server, logFields log.Fields) (bool, error) {
// TLSCheck is a TLS certificate check
type TLSCheck struct {
config *Config
}
// Check checks tls certificates from a host, ensures they're valid, and not expired.
func (t *TLSCheck) Check(server *Server, logFields log.Fields) (bool, error) {
var host, port string
var err error
@@ -144,7 +155,7 @@ func (r *Redirector) checkTLS(server *Server, logFields log.Fields) (bool, error
}
conn, err := tls.Dial("tcp", host+":"+port, &tls.Config{
RootCAs: r.config.RootCAs,
RootCAs: t.config.RootCAs,
})
if err != nil {
@@ -174,7 +185,7 @@ func (r *Redirector) checkTLS(server *Server, logFields log.Fields) (bool, error
}
opts := x509.VerifyOptions{
Roots: r.config.RootCAs,
Roots: t.config.RootCAs,
Intermediates: peerPool,
CurrentTime: time.Now(),
}
@@ -213,3 +224,83 @@ func (r *Redirector) checkTLS(server *Server, logFields log.Fields) (bool, error
return true, nil
}
type VersionCheck struct {
config *Config
VersionURL string
lastVersion string
lastVersionTime time.Time
}
func (v *VersionCheck) getCurrentVersion() (string, error) {
if v.lastVersion != "" && time.Now().Before(v.lastVersionTime.Add(5*time.Minute)) {
return v.lastVersion, nil
}
req, err := http.NewRequest(http.MethodGet, v.VersionURL, nil)
req.Header.Set("User-Agent", "ArmbianRouter/1.0 (Go "+runtime.Version()+")")
if err != nil {
return "", err
}
res, err := v.config.checkClient.Do(req)
if err != nil {
return "", err
}
defer res.Body.Close()
b, err := io.ReadAll(io.LimitReader(res.Body, 128))
if err != nil {
return "", err
}
v.lastVersion = string(b)
v.lastVersionTime = time.Now()
return v.lastVersion, nil
}
func (v *VersionCheck) Check(server *Server, logFields log.Fields) (bool, error) {
currentVersion, err := v.getCurrentVersion()
if err != nil {
return false, err
}
controlPath := path.Join(server.Path, ".control")
u := &url.URL{
Scheme: "https",
Host: server.Host,
Path: controlPath,
}
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
req.Header.Set("User-Agent", "ArmbianRouter/1.0 (Go "+runtime.Version()+")")
if err != nil {
return false, err
}
res, err := v.config.checkClient.Do(req)
if err != nil {
return false, err
}
defer res.Body.Close()
b, err := io.ReadAll(io.LimitReader(res.Body, 128))
if err != nil {
return false, err
}
return string(b) == currentVersion, nil
}

View File

@@ -64,7 +64,9 @@ var _ = Describe("Check suite", func() {
httpServer = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler(w, r)
}))
r = New(&Config{})
r = New(&Config{
checkClient: &http.Client{},
})
r.config.SetRootCAs(x509.NewCertPool())
})
AfterEach(func() {
@@ -83,16 +85,18 @@ var _ = Describe("Check suite", func() {
}
Context("HTTP Checks", func() {
var h *HTTPCheck
BeforeEach(func() {
httpServer.Start()
setupServer()
h = &HTTPCheck{config: r.config}
})
It("Should successfully check for connectivity", func() {
handler = func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
res, err := r.checkHTTPScheme(server, "http", log.Fields{})
res, err := h.checkHTTPScheme(server, "http", log.Fields{})
Expect(res).To(BeTrue())
Expect(err).To(BeNil())
@@ -101,6 +105,7 @@ var _ = Describe("Check suite", func() {
Context("TLS Checks", func() {
var (
x509Cert *x509.Certificate
t *TLSCheck
)
setupCerts := func(notBefore, notAfter time.Time) {
cert, key, err := genTestCerts(notBefore, notAfter)
@@ -131,10 +136,13 @@ var _ = Describe("Check suite", func() {
r.config.SetRootCAs(pool)
t = &TLSCheck{config: r.config}
httpServer.StartTLS()
setupServer()
}
Context("HTTPS Checks", func() {
h := &HTTPCheck{config: r.config}
BeforeEach(func() {
setupCerts(time.Now(), time.Now().Add(24*time.Hour))
})
@@ -146,7 +154,7 @@ var _ = Describe("Check suite", func() {
logFields := log.Fields{}
res, err := r.checkHTTPScheme(server, "https", logFields)
res, err := h.checkHTTPScheme(server, "https", logFields)
Expect(logFields["url"]).ToNot(BeEmpty())
Expect(logFields["url"]).ToNot(Equal(httpServer.URL))
@@ -161,13 +169,13 @@ var _ = Describe("Check suite", func() {
It("Should fail due to invalid ca", func() {
r.config.SetRootCAs(x509.NewCertPool())
res, err := r.checkTLS(server, log.Fields{})
res, err := t.Check(server, log.Fields{})
Expect(res).To(BeFalse())
Expect(err).ToNot(BeNil())
})
It("Should successfully validate certificates (valid ca, valid date/times, etc)", func() {
res, err := r.checkTLS(server, log.Fields{})
res, err := t.Check(server, log.Fields{})
Expect(res).To(BeFalse())
Expect(err).ToNot(BeNil())
@@ -178,7 +186,7 @@ var _ = Describe("Check suite", func() {
setupCerts(time.Now().Add(5*time.Hour), time.Now().Add(10*time.Hour))
// Check TLS
res, err := r.checkTLS(server, log.Fields{})
res, err := t.Check(server, log.Fields{})
Expect(res).To(BeFalse())
Expect(err).ToNot(BeNil())
@@ -187,7 +195,34 @@ var _ = Describe("Check suite", func() {
setupCerts(time.Now().Add(-10*time.Hour), time.Now().Add(-5*time.Hour))
// Check TLS
res, err := r.checkTLS(server, log.Fields{})
res, err := t.Check(server, log.Fields{})
Expect(res).To(BeFalse())
Expect(err).ToNot(BeNil())
})
})
Context("Version checks", func() {
v := &VersionCheck{
config: r.config,
lastVersion: "1234567890",
lastVersionTime: time.Now(),
}
It("Should succeed and match versions", func() {
handler = func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("1234567890"))
}
res, err := v.Check(server, log.Fields{})
Expect(res).To(BeTrue())
Expect(err).To(BeNil())
})
It("Should fail due to mismatched versions", func() {
handler = func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("0987654321"))
}
res, err := v.Check(server, log.Fields{})
Expect(res).To(BeFalse())
Expect(err).ToNot(BeNil())

View File

@@ -41,6 +41,9 @@ type Config struct {
// ReloadToken is a secret token used for web-based reload.
ReloadToken string `mapstructure:"reloadToken"`
// CheckURL is the url used to verify mirror versions
CheckURL string `mapstructure:"checkUrl"`
// ServerList is a list of ServerConfig structs, which gets parsed into servers.
ServerList []ServerConfig `mapstructure:"servers"`

View File

@@ -83,12 +83,14 @@ type ASN struct {
// ServerConfig is a configuration struct holding basic server configuration.
// This is used for initial loading of server information before parsing into Server.
type ServerConfig struct {
Server string `mapstructure:"server" yaml:"server"`
Latitude float64 `mapstructure:"latitude" yaml:"latitude"`
Longitude float64 `mapstructure:"longitude" yaml:"longitude"`
Continent string `mapstructure:"continent"`
Weight int `mapstructure:"weight" yaml:"weight"`
Protocols []string `mapstructure:"protocols" yaml:"protocols"`
Server string `mapstructure:"server" yaml:"server"`
Latitude float64 `mapstructure:"latitude" yaml:"latitude"`
Longitude float64 `mapstructure:"longitude" yaml:"longitude"`
Continent string `mapstructure:"continent"`
Weight int `mapstructure:"weight" yaml:"weight"`
Protocols []string `mapstructure:"protocols" yaml:"protocols"`
IncludeASN []ASNList `mapstructure:"includeASN" yaml:"includeASN"`
ExcludeASN []ASNList `mapstructure:"excludeASN" yaml:"excludeASN"`
}
// New creates a new instance of Redirector
@@ -98,8 +100,18 @@ func New(config *Config) *Redirector {
}
r.checks = []ServerCheck{
r.checkHTTP("http"),
r.checkTLS,
&HTTPCheck{
config: config,
},
&TLSCheck{
config: config,
},
}
if config.CheckURL != "" {
r.checks = append(r.checks, &VersionCheck{
VersionURL: config.CheckURL,
})
}
return r

View File

@@ -28,7 +28,9 @@ type Server struct {
}
// ServerCheck is a check function which can return information about a status.
type ServerCheck func(server *Server, logFields log.Fields) (bool, error)
type ServerCheck interface {
Check(server *Server, logFields log.Fields) (bool, error)
}
// checkStatus runs all status checks against a server
func (server *Server) checkStatus(checks []ServerCheck) {
@@ -40,7 +42,7 @@ func (server *Server) checkStatus(checks []ServerCheck) {
var err error
for _, check := range checks {
res, err = check(server, logFields)
res, err = check.Check(server, logFields)
if err != nil {
logFields["error"] = err