server: add support for rate limiting

This commit is contained in:
Muhammed Efe Cetin
2025-03-16 18:22:38 +03:00
committed by M. Efe Çetin
parent 4766ac3370
commit 5331e14954
5 changed files with 129 additions and 69 deletions

View File

@@ -13,7 +13,8 @@ import (
"github.com/armbian/ansi-hastebin/internal/keygenerator"
"github.com/armbian/ansi-hastebin/internal/server"
"github.com/armbian/ansi-hastebin/internal/storage"
"github.com/sirupsen/logrus"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
func handleConfig(location string) (*config.Config, storage.Storage, keygenerator.KeyGenerator) {
@@ -35,7 +36,7 @@ func handleConfig(location string) (*config.Config, storage.Storage, keygenerato
case "s3":
pasteStorage = storage.NewS3Storage(cfg.Storage.Host, cfg.Storage.Port, cfg.Storage.Username, cfg.Storage.Password, cfg.Storage.AWSRegion, cfg.Storage.Bucket)
default:
logrus.Fatalf("Unknown storage type: %s", cfg.Storage.Type)
log.Fatal().Str("storage_type", cfg.Storage.Type).Msg("Unknown storage type")
return nil, nil, nil
}
@@ -43,17 +44,17 @@ func handleConfig(location string) (*config.Config, storage.Storage, keygenerato
for _, doc := range cfg.Documents {
file, err := os.OpenFile(doc.Path, os.O_RDONLY, 0644)
if err != nil {
logrus.WithError(err).WithField("path", doc.Path).Fatal("Failed to open document")
log.Fatal().Err(err).Str("path", doc.Path).Msg("Failed to open document")
}
content, err := io.ReadAll(file)
if err != nil {
logrus.WithError(err).WithField("path", doc.Path).Fatal("Failed to read document")
log.Fatal().Err(err).Str("path", doc.Path).Msg("Failed to read document")
}
file.Close()
if err := pasteStorage.Set(doc.Key, string(content), false); err != nil {
logrus.WithError(err).WithField("key", doc.Key).Fatal("Failed to set document")
log.Fatal().Err(err).Str("key", doc.Key).Msg("Failed to set document")
}
}
@@ -65,10 +66,21 @@ func handleConfig(location string) (*config.Config, storage.Storage, keygenerato
case "phonetic":
keyGenerator = keygenerator.NewPhoneticKeyGenerator()
default:
logrus.Fatalf("Unknown key generator: %s", cfg.KeyGenerator)
log.Fatal().Str("key_generator", cfg.KeyGenerator).Msg("Unknown key generator")
return nil, nil, nil
}
// Adjust logger
logLevel, err := zerolog.ParseLevel(cfg.Logging.Level)
if err != nil {
log.Fatal().Err(err).Str("level", cfg.Logging.Level).Msg("Failed to parse log level")
}
log.Logger = log.Level(logLevel)
if cfg.Logging.Colorize {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stdout})
}
return cfg, pasteStorage, keyGenerator
}

View File

@@ -15,6 +15,11 @@ documents:
- key: "about"
path: "./about.md"
rate_limiting:
enable: true
limit: 500
window: 15
logging:
level: "info"
type: "text"

View File

@@ -5,14 +5,27 @@ import (
"strconv"
"strings"
"github.com/sirupsen/logrus"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v3"
)
type LoggingConfig struct {
Level string `yaml:"level"`
Type string `yaml:"type"`
Colorize bool `yaml:"colorize"`
// Level is the logging level to use
Level string `yaml:"level"`
// Colorize is a flag to enable colorized output
Colorize bool `yaml:"colorize"`
}
type RateLimitingConfig struct {
// Enable is a flag to enable rate limiting
Enable bool `yaml:"enable"`
// Limit is the maximum number of requests
Limit int `yaml:"limit"`
// Window is the time window to limit requests
Window int `yaml:"window"`
}
type StorageConfig struct {
@@ -94,6 +107,9 @@ type Config struct {
// Logging is the logging configuration
Logging LoggingConfig `yaml:"logging"`
// RateLimiting is the rate limiting configuration
RateLimiting RateLimitingConfig `yaml:"rate_limiting"`
// Documents is the list of documents to load statically
Documents []DocumentConfig `yaml:"documents"`
}
@@ -112,7 +128,6 @@ var DefaultConfig = &Config{
},
Logging: LoggingConfig{
Level: "info",
Type: "text",
},
Documents: []DocumentConfig{
{
@@ -129,12 +144,12 @@ func NewConfig(configFile string) *Config {
// Read the configuration file
data, err := os.ReadFile(configFile)
if err != nil && !os.IsNotExist(err) {
logrus.WithError(err).Fatal("Failed to read configuration file")
log.Fatal().Err(err).Msg("Failed to read configuration file")
}
// Unmarshal the configuration file
if err := yaml.Unmarshal(data, cfg); err != nil {
logrus.WithError(err).Fatal("Failed to unmarshal configuration file")
log.Fatal().Err(err).Msg("Failed to unmarshal configuration file")
}
// Override with environment variables
@@ -145,7 +160,7 @@ func NewConfig(configFile string) *Config {
if port := os.Getenv("PORT"); port != "" {
portInt, err := strconv.Atoi(port)
if err != nil {
logrus.WithError(err).Fatal("Failed to parse PORT environment variable")
log.Fatal().Err(err).Msg("Failed to parse PORT environment variable")
}
cfg.Port = portInt
}
@@ -153,7 +168,7 @@ func NewConfig(configFile string) *Config {
if keyLength := os.Getenv("KEY_LENGTH"); keyLength != "" {
keyLengthInt, err := strconv.Atoi(keyLength)
if err != nil {
logrus.WithError(err).Fatal("Failed to parse KEY_LENGTH environment variable")
log.Fatal().Err(err).Msg("Failed to parse KEY_LENGTH environment variable")
}
cfg.KeyLength = keyLengthInt
}
@@ -161,7 +176,7 @@ func NewConfig(configFile string) *Config {
if maxLength := os.Getenv("MAX_LENGTH"); maxLength != "" {
maxLengthInt, err := strconv.Atoi(maxLength)
if err != nil {
logrus.WithError(err).Fatal("Failed to parse MAX_LENGTH environment variable")
log.Fatal().Err(err).Msg("Failed to parse MAX_LENGTH environment variable")
}
cfg.MaxLength = maxLengthInt
}
@@ -169,7 +184,7 @@ func NewConfig(configFile string) *Config {
if staticMaxAge := os.Getenv("STATIC_MAX_AGE"); staticMaxAge != "" {
staticMaxAgeInt, err := strconv.Atoi(staticMaxAge)
if err != nil {
logrus.WithError(err).Fatal("Failed to parse STATIC_MAX_AGE environment variable")
log.Fatal().Err(err).Msg("Failed to parse STATIC_MAX_AGE environment variable")
}
cfg.StaticMaxAge = staticMaxAgeInt
}
@@ -177,7 +192,7 @@ func NewConfig(configFile string) *Config {
if recompressStaticAssets := os.Getenv("RECOMPRESS_STATIC_ASSETS"); recompressStaticAssets != "" {
recompressStaticAssetsBool, err := strconv.ParseBool(recompressStaticAssets)
if err != nil {
logrus.WithError(err).Fatal("Failed to parse RECOMPRESS_STATIC_ASSETS environment variable")
log.Fatal().Err(err).Msg("Failed to parse RECOMPRESS_STATIC_ASSETS environment variable")
}
cfg.RecompressStaticAssets = recompressStaticAssetsBool
}
@@ -197,7 +212,7 @@ func NewConfig(configFile string) *Config {
if storagePort := os.Getenv("STORAGE_PORT"); storagePort != "" {
storagePortInt, err := strconv.Atoi(storagePort)
if err != nil {
logrus.WithError(err).Fatal("Failed to parse STORAGE_PORT environment variable")
log.Fatal().Err(err).Msg("Failed to parse STORAGE_PORT environment variable")
}
cfg.Storage.Port = storagePortInt
}
@@ -230,18 +245,41 @@ func NewConfig(configFile string) *Config {
cfg.Logging.Level = loggingLevel
}
if loggingType := os.Getenv("LOGGING_TYPE"); loggingType != "" {
cfg.Logging.Type = loggingType
}
if loggingColorize := os.Getenv("LOGGING_COLORIZE"); loggingColorize != "" {
loggingColorizeBool, err := strconv.ParseBool(loggingColorize)
if err != nil {
logrus.WithError(err).Fatal("Failed to parse LOGGING_COLORIZE environment variable")
log.Fatal().Err(err).Msg("Failed to parse LOGGING_COLORIZE environment variable")
}
cfg.Logging.Colorize = loggingColorizeBool
}
if rateLimitingEnable := os.Getenv("RATE_LIMITING_ENABLE"); rateLimitingEnable != "" {
rateLimitingEnableBool, err := strconv.ParseBool(rateLimitingEnable)
if err != nil {
log.Fatal().Err(err).Msg("Failed to parse RATE_LIMITING_ENABLE environment variable")
}
cfg.RateLimiting.Enable = rateLimitingEnableBool
}
if rateLimitingLimit := os.Getenv("RATE_LIMITING_LIMIT"); rateLimitingLimit != "" {
rateLimitingLimitInt, err := strconv.Atoi(rateLimitingLimit)
if err != nil {
log.Fatal().Err(err).Msg("Failed to parse RATE_LIMITING_LIMIT environment variable")
}
cfg.RateLimiting.Limit = rateLimitingLimitInt
}
if rateLimitingWindow := os.Getenv("RATE_LIMITING_WINDOW"); rateLimitingWindow != "" {
rateLimitingWindowInt, err := strconv.Atoi(rateLimitingWindow)
if err != nil {
log.Fatal().Err(err).Msg("Failed to parse RATE_LIMITING_WINDOW environment variable")
}
cfg.RateLimiting.Window = rateLimitingWindowInt
}
// Walk environment variables for documents
for _, env := range os.Environ() {
if len(env) > 10 && env[:10] == "DOCUMENTS_" {
@@ -292,9 +330,5 @@ func NewConfig(configFile string) *Config {
cfg.Logging.Level = DefaultConfig.Logging.Level
}
if cfg.Logging.Type == "" {
cfg.Logging.Type = DefaultConfig.Logging.Type
}
return cfg
}

View File

@@ -4,44 +4,47 @@ import (
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewConfig_DefaultValues(t *testing.T) {
cfg := NewConfig("nonexistent.yaml") // Should use defaults since file doesn't exist
assert.Equal(t, "0.0.0.0", cfg.Host)
assert.Equal(t, 7777, cfg.Port)
assert.Equal(t, 10, cfg.KeyLength)
assert.Equal(t, "phonetic", cfg.KeyGenerator)
assert.Equal(t, "file", cfg.Storage.Type)
assert.Equal(t, "data", cfg.Storage.FilePath)
assert.Equal(t, "info", cfg.Logging.Level)
assert.Equal(t, "text", cfg.Logging.Type)
require.Equal(t, "0.0.0.0", cfg.Host)
require.Equal(t, 7777, cfg.Port)
require.Equal(t, 10, cfg.KeyLength)
require.Equal(t, "phonetic", cfg.KeyGenerator)
require.Equal(t, "file", cfg.Storage.Type)
require.Equal(t, "data", cfg.Storage.FilePath)
require.Equal(t, "info", cfg.Logging.Level)
}
func TestNewConfig_OverrideWithEnvVars(t *testing.T) {
os.Setenv("HOST", "127.0.0.1")
os.Setenv("PORT", "8080")
os.Setenv("KEY_LENGTH", "15")
os.Setenv("MAX_LENGTH", "5000000")
os.Setenv("STORAGE_TYPE", "redis")
os.Setenv("STORAGE_HOST", "localhost")
os.Setenv("STORAGE_PORT", "6379")
os.Setenv("LOGGING_LEVEL", "debug")
t.Setenv("HOST", "127.0.0.1")
t.Setenv("PORT", "8080")
t.Setenv("KEY_LENGTH", "15")
t.Setenv("MAX_LENGTH", "5000000")
t.Setenv("STORAGE_TYPE", "redis")
t.Setenv("STORAGE_HOST", "localhost")
t.Setenv("STORAGE_PORT", "6379")
t.Setenv("LOGGING_LEVEL", "debug")
t.Setenv("RATE_LIMITING_ENABLE", "true")
t.Setenv("RATE_LIMITING_LIMIT", "100")
defer os.Clearenv()
cfg := NewConfig("nonexistent.yaml") // Load with environment variables
assert.Equal(t, "127.0.0.1", cfg.Host)
assert.Equal(t, 8080, cfg.Port)
assert.Equal(t, 15, cfg.KeyLength)
assert.Equal(t, 5000000, cfg.MaxLength)
assert.Equal(t, "redis", cfg.Storage.Type)
assert.Equal(t, "localhost", cfg.Storage.Host)
assert.Equal(t, 6379, cfg.Storage.Port)
assert.Equal(t, "debug", cfg.Logging.Level)
require.Equal(t, "127.0.0.1", cfg.Host)
require.Equal(t, 8080, cfg.Port)
require.Equal(t, 15, cfg.KeyLength)
require.Equal(t, 5000000, cfg.MaxLength)
require.Equal(t, "redis", cfg.Storage.Type)
require.Equal(t, "localhost", cfg.Storage.Host)
require.Equal(t, 6379, cfg.Storage.Port)
require.Equal(t, "debug", cfg.Logging.Level)
require.Equal(t, true, cfg.RateLimiting.Enable)
require.Equal(t, 100, cfg.RateLimiting.Limit)
}
func TestNewConfig_LoadFromYAML(t *testing.T) {
@@ -60,21 +63,20 @@ logging:
// Write to a temporary file
tmpFile, err := os.CreateTemp("", "config_test_*.yaml")
assert.NoError(t, err)
require.NoError(t, err)
defer os.Remove(tmpFile.Name())
_, err = tmpFile.Write([]byte(yamlContent))
assert.NoError(t, err)
require.NoError(t, err)
tmpFile.Close()
cfg := NewConfig(tmpFile.Name()) // Load config from YAML file
assert.Equal(t, "192.168.1.1", cfg.Host)
assert.Equal(t, 9090, cfg.Port)
assert.Equal(t, 20, cfg.KeyLength)
assert.Equal(t, "mongodb", cfg.Storage.Type)
assert.Equal(t, "mongo.example.com", cfg.Storage.Host)
assert.Equal(t, 27017, cfg.Storage.Port)
assert.Equal(t, "warn", cfg.Logging.Level)
assert.Equal(t, "json", cfg.Logging.Type)
require.Equal(t, "192.168.1.1", cfg.Host)
require.Equal(t, 9090, cfg.Port)
require.Equal(t, 20, cfg.KeyLength)
require.Equal(t, "mongodb", cfg.Storage.Type)
require.Equal(t, "mongo.example.com", cfg.Storage.Host)
require.Equal(t, 27017, cfg.Storage.Port)
require.Equal(t, "warn", cfg.Logging.Level)
}

View File

@@ -6,6 +6,7 @@ import (
"net/http"
"os"
"strconv"
"time"
"github.com/armbian/ansi-hastebin/config"
"github.com/armbian/ansi-hastebin/handler"
@@ -13,8 +14,9 @@ import (
"github.com/armbian/ansi-hastebin/internal/storage"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/httprate"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/sirupsen/logrus"
"github.com/rs/zerolog/log"
)
type Server struct {
@@ -46,6 +48,11 @@ func (s *Server) RegisterRoutes() {
s.mux.Use(middleware.Logger)
s.mux.Use(middleware.Recoverer)
// Rate limiter
if s.config.RateLimiting.Enable {
s.mux.Use(httprate.LimitByRealIP(s.config.RateLimiting.Limit, time.Duration(s.config.RateLimiting.Window)*time.Second))
}
// Register promhttp middleware
s.mux.Get("/metrics", promhttp.Handler().ServeHTTP)
@@ -86,21 +93,21 @@ func (s *Server) RegisterRoutes() {
}
func (s *Server) Start() {
logrus.Infof("Starting server on %s", s.server.Addr)
log.Info().Str("host", s.config.Host).Int("port", s.config.Port).Msg("Starting server")
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logrus.WithError(err).Fatal("Failed to start server")
log.Fatal().Err(err).Msg("Failed to start server")
}
}
func (s *Server) Shutdown(ctx context.Context) {
logrus.Info("Gracefully shutting down server")
log.Info().Msg("Gracefully shutting down server")
if err := s.storage.Close(); err != nil {
logrus.WithError(err).Error("Failed to close storage")
log.Error().Err(err).Msg("Failed to close storage")
}
if err := s.server.Shutdown(ctx); err != nil {
logrus.WithError(err).Error("Failed to shutdown server")
log.Error().Err(err).Msg("Failed to shutdown server")
}
}