You've already forked armbian-router
mirror of
https://github.com/armbian/armbian-router.git
synced 2026-01-06 10:37:03 -08:00
WIP: Version checks
This commit is contained in:
125
check.go
125
check.go
@@ -6,9 +6,11 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -28,14 +30,18 @@ var (
|
||||
ErrCertExpired = errors.New("certificate is expired")
|
||||
)
|
||||
|
||||
func (r *Redirector) checkHTTP(scheme string) ServerCheck {
|
||||
return func(server *Server, logFields log.Fields) (bool, error) {
|
||||
return r.checkHTTPScheme(server, scheme, logFields)
|
||||
}
|
||||
// HTTPCheck is a check for validity and redirects
|
||||
type HTTPCheck struct {
|
||||
config *Config
|
||||
}
|
||||
|
||||
// checkHTTPScheme checks a URL for validity, and checks redirects
|
||||
func (r *Redirector) checkHTTPScheme(server *Server, scheme string, logFields log.Fields) (bool, error) {
|
||||
// Check checks a URL for validity, and checks redirects
|
||||
func (h *HTTPCheck) Check(server *Server, logFields log.Fields) (bool, error) {
|
||||
return h.checkHTTPScheme(server, "http", logFields)
|
||||
}
|
||||
|
||||
// checkHTTPScheme will check if a scheme is valid and doesn't redirect
|
||||
func (h *HTTPCheck) checkHTTPScheme(server *Server, scheme string, logFields log.Fields) (bool, error) {
|
||||
u := &url.URL{
|
||||
Scheme: scheme,
|
||||
Host: server.Host,
|
||||
@@ -50,7 +56,7 @@ func (r *Redirector) checkHTTPScheme(server *Server, scheme string, logFields lo
|
||||
return false, err
|
||||
}
|
||||
|
||||
res, err := r.config.checkClient.Do(req)
|
||||
res, err := h.config.checkClient.Do(req)
|
||||
|
||||
if err != nil {
|
||||
return false, err
|
||||
@@ -66,18 +72,18 @@ func (r *Redirector) checkHTTPScheme(server *Server, scheme string, logFields lo
|
||||
|
||||
switch u.Scheme {
|
||||
case "http":
|
||||
res, err := r.checkRedirect(u.Scheme, location)
|
||||
res, err := h.checkRedirect(u.Scheme, location)
|
||||
|
||||
if !res || err != nil {
|
||||
// If we don't support http, we remove it from supported protocols
|
||||
server.Protocols = server.Protocols.Remove("http")
|
||||
} else {
|
||||
// Otherwise, we verify https support
|
||||
r.checkProtocol(server, "https")
|
||||
h.checkProtocol(server, "https")
|
||||
}
|
||||
case "https":
|
||||
// We don't want to allow downgrading, so this is an error.
|
||||
return r.checkRedirect(u.Scheme, location)
|
||||
return h.checkRedirect(u.Scheme, location)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,8 +95,8 @@ func (r *Redirector) checkHTTPScheme(server *Server, scheme string, logFields lo
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *Redirector) checkProtocol(server *Server, scheme string) {
|
||||
res, err := r.checkHTTPScheme(server, scheme, log.Fields{})
|
||||
func (h *HTTPCheck) checkProtocol(server *Server, scheme string) {
|
||||
res, err := h.checkHTTPScheme(server, scheme, log.Fields{})
|
||||
|
||||
if !res || err != nil {
|
||||
return
|
||||
@@ -102,7 +108,7 @@ func (r *Redirector) checkProtocol(server *Server, scheme string) {
|
||||
}
|
||||
|
||||
// checkRedirect parses a location header response and checks the scheme
|
||||
func (r *Redirector) checkRedirect(originatingScheme, locationHeader string) (bool, error) {
|
||||
func (h *HTTPCheck) checkRedirect(originatingScheme, locationHeader string) (bool, error) {
|
||||
newURL, err := url.Parse(locationHeader)
|
||||
|
||||
if err != nil {
|
||||
@@ -118,8 +124,13 @@ func (r *Redirector) checkRedirect(originatingScheme, locationHeader string) (bo
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// checkTLS checks tls certificates from a host, ensures they're valid, and not expired.
|
||||
func (r *Redirector) checkTLS(server *Server, logFields log.Fields) (bool, error) {
|
||||
// TLSCheck is a TLS certificate check
|
||||
type TLSCheck struct {
|
||||
config *Config
|
||||
}
|
||||
|
||||
// Check checks tls certificates from a host, ensures they're valid, and not expired.
|
||||
func (t *TLSCheck) Check(server *Server, logFields log.Fields) (bool, error) {
|
||||
var host, port string
|
||||
var err error
|
||||
|
||||
@@ -144,7 +155,7 @@ func (r *Redirector) checkTLS(server *Server, logFields log.Fields) (bool, error
|
||||
}
|
||||
|
||||
conn, err := tls.Dial("tcp", host+":"+port, &tls.Config{
|
||||
RootCAs: r.config.RootCAs,
|
||||
RootCAs: t.config.RootCAs,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
@@ -174,7 +185,7 @@ func (r *Redirector) checkTLS(server *Server, logFields log.Fields) (bool, error
|
||||
}
|
||||
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: r.config.RootCAs,
|
||||
Roots: t.config.RootCAs,
|
||||
Intermediates: peerPool,
|
||||
CurrentTime: time.Now(),
|
||||
}
|
||||
@@ -213,3 +224,83 @@ func (r *Redirector) checkTLS(server *Server, logFields log.Fields) (bool, error
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
type VersionCheck struct {
|
||||
config *Config
|
||||
VersionURL string
|
||||
lastVersion string
|
||||
lastVersionTime time.Time
|
||||
}
|
||||
|
||||
func (v *VersionCheck) getCurrentVersion() (string, error) {
|
||||
if v.lastVersion != "" && time.Now().Before(v.lastVersionTime.Add(5*time.Minute)) {
|
||||
return v.lastVersion, nil
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, v.VersionURL, nil)
|
||||
|
||||
req.Header.Set("User-Agent", "ArmbianRouter/1.0 (Go "+runtime.Version()+")")
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
res, err := v.config.checkClient.Do(req)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
|
||||
b, err := io.ReadAll(io.LimitReader(res.Body, 128))
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
v.lastVersion = string(b)
|
||||
v.lastVersionTime = time.Now()
|
||||
|
||||
return v.lastVersion, nil
|
||||
}
|
||||
|
||||
func (v *VersionCheck) Check(server *Server, logFields log.Fields) (bool, error) {
|
||||
currentVersion, err := v.getCurrentVersion()
|
||||
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
controlPath := path.Join(server.Path, ".control")
|
||||
|
||||
u := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: server.Host,
|
||||
Path: controlPath,
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
|
||||
|
||||
req.Header.Set("User-Agent", "ArmbianRouter/1.0 (Go "+runtime.Version()+")")
|
||||
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
res, err := v.config.checkClient.Do(req)
|
||||
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
|
||||
b, err := io.ReadAll(io.LimitReader(res.Body, 128))
|
||||
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return string(b) == currentVersion, nil
|
||||
}
|
||||
|
||||
@@ -64,7 +64,9 @@ var _ = Describe("Check suite", func() {
|
||||
httpServer = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handler(w, r)
|
||||
}))
|
||||
r = New(&Config{})
|
||||
r = New(&Config{
|
||||
checkClient: &http.Client{},
|
||||
})
|
||||
r.config.SetRootCAs(x509.NewCertPool())
|
||||
})
|
||||
AfterEach(func() {
|
||||
@@ -83,16 +85,18 @@ var _ = Describe("Check suite", func() {
|
||||
}
|
||||
|
||||
Context("HTTP Checks", func() {
|
||||
var h *HTTPCheck
|
||||
BeforeEach(func() {
|
||||
httpServer.Start()
|
||||
setupServer()
|
||||
h = &HTTPCheck{config: r.config}
|
||||
})
|
||||
It("Should successfully check for connectivity", func() {
|
||||
handler = func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
res, err := r.checkHTTPScheme(server, "http", log.Fields{})
|
||||
res, err := h.checkHTTPScheme(server, "http", log.Fields{})
|
||||
|
||||
Expect(res).To(BeTrue())
|
||||
Expect(err).To(BeNil())
|
||||
@@ -101,6 +105,7 @@ var _ = Describe("Check suite", func() {
|
||||
Context("TLS Checks", func() {
|
||||
var (
|
||||
x509Cert *x509.Certificate
|
||||
t *TLSCheck
|
||||
)
|
||||
setupCerts := func(notBefore, notAfter time.Time) {
|
||||
cert, key, err := genTestCerts(notBefore, notAfter)
|
||||
@@ -131,10 +136,13 @@ var _ = Describe("Check suite", func() {
|
||||
|
||||
r.config.SetRootCAs(pool)
|
||||
|
||||
t = &TLSCheck{config: r.config}
|
||||
|
||||
httpServer.StartTLS()
|
||||
setupServer()
|
||||
}
|
||||
Context("HTTPS Checks", func() {
|
||||
h := &HTTPCheck{config: r.config}
|
||||
BeforeEach(func() {
|
||||
setupCerts(time.Now(), time.Now().Add(24*time.Hour))
|
||||
})
|
||||
@@ -146,7 +154,7 @@ var _ = Describe("Check suite", func() {
|
||||
|
||||
logFields := log.Fields{}
|
||||
|
||||
res, err := r.checkHTTPScheme(server, "https", logFields)
|
||||
res, err := h.checkHTTPScheme(server, "https", logFields)
|
||||
|
||||
Expect(logFields["url"]).ToNot(BeEmpty())
|
||||
Expect(logFields["url"]).ToNot(Equal(httpServer.URL))
|
||||
@@ -161,13 +169,13 @@ var _ = Describe("Check suite", func() {
|
||||
It("Should fail due to invalid ca", func() {
|
||||
r.config.SetRootCAs(x509.NewCertPool())
|
||||
|
||||
res, err := r.checkTLS(server, log.Fields{})
|
||||
res, err := t.Check(server, log.Fields{})
|
||||
|
||||
Expect(res).To(BeFalse())
|
||||
Expect(err).ToNot(BeNil())
|
||||
})
|
||||
It("Should successfully validate certificates (valid ca, valid date/times, etc)", func() {
|
||||
res, err := r.checkTLS(server, log.Fields{})
|
||||
res, err := t.Check(server, log.Fields{})
|
||||
|
||||
Expect(res).To(BeFalse())
|
||||
Expect(err).ToNot(BeNil())
|
||||
@@ -178,7 +186,7 @@ var _ = Describe("Check suite", func() {
|
||||
setupCerts(time.Now().Add(5*time.Hour), time.Now().Add(10*time.Hour))
|
||||
|
||||
// Check TLS
|
||||
res, err := r.checkTLS(server, log.Fields{})
|
||||
res, err := t.Check(server, log.Fields{})
|
||||
|
||||
Expect(res).To(BeFalse())
|
||||
Expect(err).ToNot(BeNil())
|
||||
@@ -187,7 +195,34 @@ var _ = Describe("Check suite", func() {
|
||||
setupCerts(time.Now().Add(-10*time.Hour), time.Now().Add(-5*time.Hour))
|
||||
|
||||
// Check TLS
|
||||
res, err := r.checkTLS(server, log.Fields{})
|
||||
res, err := t.Check(server, log.Fields{})
|
||||
|
||||
Expect(res).To(BeFalse())
|
||||
Expect(err).ToNot(BeNil())
|
||||
})
|
||||
})
|
||||
Context("Version checks", func() {
|
||||
v := &VersionCheck{
|
||||
config: r.config,
|
||||
lastVersion: "1234567890",
|
||||
lastVersionTime: time.Now(),
|
||||
}
|
||||
It("Should succeed and match versions", func() {
|
||||
handler = func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("1234567890"))
|
||||
}
|
||||
|
||||
res, err := v.Check(server, log.Fields{})
|
||||
|
||||
Expect(res).To(BeTrue())
|
||||
Expect(err).To(BeNil())
|
||||
})
|
||||
It("Should fail due to mismatched versions", func() {
|
||||
handler = func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("0987654321"))
|
||||
}
|
||||
|
||||
res, err := v.Check(server, log.Fields{})
|
||||
|
||||
Expect(res).To(BeFalse())
|
||||
Expect(err).ToNot(BeNil())
|
||||
|
||||
@@ -41,6 +41,9 @@ type Config struct {
|
||||
// ReloadToken is a secret token used for web-based reload.
|
||||
ReloadToken string `mapstructure:"reloadToken"`
|
||||
|
||||
// CheckURL is the url used to verify mirror versions
|
||||
CheckURL string `mapstructure:"checkUrl"`
|
||||
|
||||
// ServerList is a list of ServerConfig structs, which gets parsed into servers.
|
||||
ServerList []ServerConfig `mapstructure:"servers"`
|
||||
|
||||
|
||||
@@ -83,12 +83,14 @@ type ASN struct {
|
||||
// ServerConfig is a configuration struct holding basic server configuration.
|
||||
// This is used for initial loading of server information before parsing into Server.
|
||||
type ServerConfig struct {
|
||||
Server string `mapstructure:"server" yaml:"server"`
|
||||
Latitude float64 `mapstructure:"latitude" yaml:"latitude"`
|
||||
Longitude float64 `mapstructure:"longitude" yaml:"longitude"`
|
||||
Continent string `mapstructure:"continent"`
|
||||
Weight int `mapstructure:"weight" yaml:"weight"`
|
||||
Protocols []string `mapstructure:"protocols" yaml:"protocols"`
|
||||
Server string `mapstructure:"server" yaml:"server"`
|
||||
Latitude float64 `mapstructure:"latitude" yaml:"latitude"`
|
||||
Longitude float64 `mapstructure:"longitude" yaml:"longitude"`
|
||||
Continent string `mapstructure:"continent"`
|
||||
Weight int `mapstructure:"weight" yaml:"weight"`
|
||||
Protocols []string `mapstructure:"protocols" yaml:"protocols"`
|
||||
IncludeASN []ASNList `mapstructure:"includeASN" yaml:"includeASN"`
|
||||
ExcludeASN []ASNList `mapstructure:"excludeASN" yaml:"excludeASN"`
|
||||
}
|
||||
|
||||
// New creates a new instance of Redirector
|
||||
@@ -98,8 +100,18 @@ func New(config *Config) *Redirector {
|
||||
}
|
||||
|
||||
r.checks = []ServerCheck{
|
||||
r.checkHTTP("http"),
|
||||
r.checkTLS,
|
||||
&HTTPCheck{
|
||||
config: config,
|
||||
},
|
||||
&TLSCheck{
|
||||
config: config,
|
||||
},
|
||||
}
|
||||
|
||||
if config.CheckURL != "" {
|
||||
r.checks = append(r.checks, &VersionCheck{
|
||||
VersionURL: config.CheckURL,
|
||||
})
|
||||
}
|
||||
|
||||
return r
|
||||
|
||||
@@ -28,7 +28,9 @@ type Server struct {
|
||||
}
|
||||
|
||||
// ServerCheck is a check function which can return information about a status.
|
||||
type ServerCheck func(server *Server, logFields log.Fields) (bool, error)
|
||||
type ServerCheck interface {
|
||||
Check(server *Server, logFields log.Fields) (bool, error)
|
||||
}
|
||||
|
||||
// checkStatus runs all status checks against a server
|
||||
func (server *Server) checkStatus(checks []ServerCheck) {
|
||||
@@ -40,7 +42,7 @@ func (server *Server) checkStatus(checks []ServerCheck) {
|
||||
var err error
|
||||
|
||||
for _, check := range checks {
|
||||
res, err = check(server, logFields)
|
||||
res, err = check.Check(server, logFields)
|
||||
|
||||
if err != nil {
|
||||
logFields["error"] = err
|
||||
|
||||
Reference in New Issue
Block a user