From 5331e149540dc0c3ebf79d8785a7047f161098df Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sun, 16 Mar 2025 18:22:38 +0300 Subject: [PATCH] server: add support for rate limiting --- cmd/main.go | 24 +++++++++--- config.yaml | 5 +++ config/config.go | 78 ++++++++++++++++++++++++++++----------- config/config_test.go | 72 ++++++++++++++++++------------------ internal/server/server.go | 19 +++++++--- 5 files changed, 129 insertions(+), 69 deletions(-) diff --git a/cmd/main.go b/cmd/main.go index 8c3b38b..49adc4c 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -13,7 +13,8 @@ import ( "github.com/armbian/ansi-hastebin/internal/keygenerator" "github.com/armbian/ansi-hastebin/internal/server" "github.com/armbian/ansi-hastebin/internal/storage" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" ) func handleConfig(location string) (*config.Config, storage.Storage, keygenerator.KeyGenerator) { @@ -35,7 +36,7 @@ func handleConfig(location string) (*config.Config, storage.Storage, keygenerato case "s3": pasteStorage = storage.NewS3Storage(cfg.Storage.Host, cfg.Storage.Port, cfg.Storage.Username, cfg.Storage.Password, cfg.Storage.AWSRegion, cfg.Storage.Bucket) default: - logrus.Fatalf("Unknown storage type: %s", cfg.Storage.Type) + log.Fatal().Str("storage_type", cfg.Storage.Type).Msg("Unknown storage type") return nil, nil, nil } @@ -43,17 +44,17 @@ func handleConfig(location string) (*config.Config, storage.Storage, keygenerato for _, doc := range cfg.Documents { file, err := os.OpenFile(doc.Path, os.O_RDONLY, 0644) if err != nil { - logrus.WithError(err).WithField("path", doc.Path).Fatal("Failed to open document") + log.Fatal().Err(err).Str("path", doc.Path).Msg("Failed to open document") } content, err := io.ReadAll(file) if err != nil { - logrus.WithError(err).WithField("path", doc.Path).Fatal("Failed to read document") + log.Fatal().Err(err).Str("path", doc.Path).Msg("Failed to read document") } file.Close() if err := pasteStorage.Set(doc.Key, string(content), false); err != nil { - logrus.WithError(err).WithField("key", doc.Key).Fatal("Failed to set document") + log.Fatal().Err(err).Str("key", doc.Key).Msg("Failed to set document") } } @@ -65,10 +66,21 @@ func handleConfig(location string) (*config.Config, storage.Storage, keygenerato case "phonetic": keyGenerator = keygenerator.NewPhoneticKeyGenerator() default: - logrus.Fatalf("Unknown key generator: %s", cfg.KeyGenerator) + log.Fatal().Str("key_generator", cfg.KeyGenerator).Msg("Unknown key generator") return nil, nil, nil } + // Adjust logger + logLevel, err := zerolog.ParseLevel(cfg.Logging.Level) + if err != nil { + log.Fatal().Err(err).Str("level", cfg.Logging.Level).Msg("Failed to parse log level") + } + log.Logger = log.Level(logLevel) + + if cfg.Logging.Colorize { + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stdout}) + } + return cfg, pasteStorage, keyGenerator } diff --git a/config.yaml b/config.yaml index c849f4f..11cecf7 100644 --- a/config.yaml +++ b/config.yaml @@ -15,6 +15,11 @@ documents: - key: "about" path: "./about.md" +rate_limiting: + enable: true + limit: 500 + window: 15 + logging: level: "info" type: "text" diff --git a/config/config.go b/config/config.go index 9438424..71dc5c7 100644 --- a/config/config.go +++ b/config/config.go @@ -5,14 +5,27 @@ import ( "strconv" "strings" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog/log" "gopkg.in/yaml.v3" ) type LoggingConfig struct { - Level string `yaml:"level"` - Type string `yaml:"type"` - Colorize bool `yaml:"colorize"` + // Level is the logging level to use + Level string `yaml:"level"` + + // Colorize is a flag to enable colorized output + Colorize bool `yaml:"colorize"` +} + +type RateLimitingConfig struct { + // Enable is a flag to enable rate limiting + Enable bool `yaml:"enable"` + + // Limit is the maximum number of requests + Limit int `yaml:"limit"` + + // Window is the time window to limit requests + Window int `yaml:"window"` } type StorageConfig struct { @@ -94,6 +107,9 @@ type Config struct { // Logging is the logging configuration Logging LoggingConfig `yaml:"logging"` + // RateLimiting is the rate limiting configuration + RateLimiting RateLimitingConfig `yaml:"rate_limiting"` + // Documents is the list of documents to load statically Documents []DocumentConfig `yaml:"documents"` } @@ -112,7 +128,6 @@ var DefaultConfig = &Config{ }, Logging: LoggingConfig{ Level: "info", - Type: "text", }, Documents: []DocumentConfig{ { @@ -129,12 +144,12 @@ func NewConfig(configFile string) *Config { // Read the configuration file data, err := os.ReadFile(configFile) if err != nil && !os.IsNotExist(err) { - logrus.WithError(err).Fatal("Failed to read configuration file") + log.Fatal().Err(err).Msg("Failed to read configuration file") } // Unmarshal the configuration file if err := yaml.Unmarshal(data, cfg); err != nil { - logrus.WithError(err).Fatal("Failed to unmarshal configuration file") + log.Fatal().Err(err).Msg("Failed to unmarshal configuration file") } // Override with environment variables @@ -145,7 +160,7 @@ func NewConfig(configFile string) *Config { if port := os.Getenv("PORT"); port != "" { portInt, err := strconv.Atoi(port) if err != nil { - logrus.WithError(err).Fatal("Failed to parse PORT environment variable") + log.Fatal().Err(err).Msg("Failed to parse PORT environment variable") } cfg.Port = portInt } @@ -153,7 +168,7 @@ func NewConfig(configFile string) *Config { if keyLength := os.Getenv("KEY_LENGTH"); keyLength != "" { keyLengthInt, err := strconv.Atoi(keyLength) if err != nil { - logrus.WithError(err).Fatal("Failed to parse KEY_LENGTH environment variable") + log.Fatal().Err(err).Msg("Failed to parse KEY_LENGTH environment variable") } cfg.KeyLength = keyLengthInt } @@ -161,7 +176,7 @@ func NewConfig(configFile string) *Config { if maxLength := os.Getenv("MAX_LENGTH"); maxLength != "" { maxLengthInt, err := strconv.Atoi(maxLength) if err != nil { - logrus.WithError(err).Fatal("Failed to parse MAX_LENGTH environment variable") + log.Fatal().Err(err).Msg("Failed to parse MAX_LENGTH environment variable") } cfg.MaxLength = maxLengthInt } @@ -169,7 +184,7 @@ func NewConfig(configFile string) *Config { if staticMaxAge := os.Getenv("STATIC_MAX_AGE"); staticMaxAge != "" { staticMaxAgeInt, err := strconv.Atoi(staticMaxAge) if err != nil { - logrus.WithError(err).Fatal("Failed to parse STATIC_MAX_AGE environment variable") + log.Fatal().Err(err).Msg("Failed to parse STATIC_MAX_AGE environment variable") } cfg.StaticMaxAge = staticMaxAgeInt } @@ -177,7 +192,7 @@ func NewConfig(configFile string) *Config { if recompressStaticAssets := os.Getenv("RECOMPRESS_STATIC_ASSETS"); recompressStaticAssets != "" { recompressStaticAssetsBool, err := strconv.ParseBool(recompressStaticAssets) if err != nil { - logrus.WithError(err).Fatal("Failed to parse RECOMPRESS_STATIC_ASSETS environment variable") + log.Fatal().Err(err).Msg("Failed to parse RECOMPRESS_STATIC_ASSETS environment variable") } cfg.RecompressStaticAssets = recompressStaticAssetsBool } @@ -197,7 +212,7 @@ func NewConfig(configFile string) *Config { if storagePort := os.Getenv("STORAGE_PORT"); storagePort != "" { storagePortInt, err := strconv.Atoi(storagePort) if err != nil { - logrus.WithError(err).Fatal("Failed to parse STORAGE_PORT environment variable") + log.Fatal().Err(err).Msg("Failed to parse STORAGE_PORT environment variable") } cfg.Storage.Port = storagePortInt } @@ -230,18 +245,41 @@ func NewConfig(configFile string) *Config { cfg.Logging.Level = loggingLevel } - if loggingType := os.Getenv("LOGGING_TYPE"); loggingType != "" { - cfg.Logging.Type = loggingType - } - if loggingColorize := os.Getenv("LOGGING_COLORIZE"); loggingColorize != "" { loggingColorizeBool, err := strconv.ParseBool(loggingColorize) if err != nil { - logrus.WithError(err).Fatal("Failed to parse LOGGING_COLORIZE environment variable") + log.Fatal().Err(err).Msg("Failed to parse LOGGING_COLORIZE environment variable") } cfg.Logging.Colorize = loggingColorizeBool } + if rateLimitingEnable := os.Getenv("RATE_LIMITING_ENABLE"); rateLimitingEnable != "" { + rateLimitingEnableBool, err := strconv.ParseBool(rateLimitingEnable) + if err != nil { + log.Fatal().Err(err).Msg("Failed to parse RATE_LIMITING_ENABLE environment variable") + } + + cfg.RateLimiting.Enable = rateLimitingEnableBool + } + + if rateLimitingLimit := os.Getenv("RATE_LIMITING_LIMIT"); rateLimitingLimit != "" { + rateLimitingLimitInt, err := strconv.Atoi(rateLimitingLimit) + if err != nil { + log.Fatal().Err(err).Msg("Failed to parse RATE_LIMITING_LIMIT environment variable") + } + + cfg.RateLimiting.Limit = rateLimitingLimitInt + } + + if rateLimitingWindow := os.Getenv("RATE_LIMITING_WINDOW"); rateLimitingWindow != "" { + rateLimitingWindowInt, err := strconv.Atoi(rateLimitingWindow) + if err != nil { + log.Fatal().Err(err).Msg("Failed to parse RATE_LIMITING_WINDOW environment variable") + } + + cfg.RateLimiting.Window = rateLimitingWindowInt + } + // Walk environment variables for documents for _, env := range os.Environ() { if len(env) > 10 && env[:10] == "DOCUMENTS_" { @@ -292,9 +330,5 @@ func NewConfig(configFile string) *Config { cfg.Logging.Level = DefaultConfig.Logging.Level } - if cfg.Logging.Type == "" { - cfg.Logging.Type = DefaultConfig.Logging.Type - } - return cfg } diff --git a/config/config_test.go b/config/config_test.go index 5406ed9..c0b530c 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -4,44 +4,47 @@ import ( "os" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewConfig_DefaultValues(t *testing.T) { cfg := NewConfig("nonexistent.yaml") // Should use defaults since file doesn't exist - assert.Equal(t, "0.0.0.0", cfg.Host) - assert.Equal(t, 7777, cfg.Port) - assert.Equal(t, 10, cfg.KeyLength) - assert.Equal(t, "phonetic", cfg.KeyGenerator) - assert.Equal(t, "file", cfg.Storage.Type) - assert.Equal(t, "data", cfg.Storage.FilePath) - assert.Equal(t, "info", cfg.Logging.Level) - assert.Equal(t, "text", cfg.Logging.Type) + require.Equal(t, "0.0.0.0", cfg.Host) + require.Equal(t, 7777, cfg.Port) + require.Equal(t, 10, cfg.KeyLength) + require.Equal(t, "phonetic", cfg.KeyGenerator) + require.Equal(t, "file", cfg.Storage.Type) + require.Equal(t, "data", cfg.Storage.FilePath) + require.Equal(t, "info", cfg.Logging.Level) } func TestNewConfig_OverrideWithEnvVars(t *testing.T) { - os.Setenv("HOST", "127.0.0.1") - os.Setenv("PORT", "8080") - os.Setenv("KEY_LENGTH", "15") - os.Setenv("MAX_LENGTH", "5000000") - os.Setenv("STORAGE_TYPE", "redis") - os.Setenv("STORAGE_HOST", "localhost") - os.Setenv("STORAGE_PORT", "6379") - os.Setenv("LOGGING_LEVEL", "debug") + t.Setenv("HOST", "127.0.0.1") + t.Setenv("PORT", "8080") + t.Setenv("KEY_LENGTH", "15") + t.Setenv("MAX_LENGTH", "5000000") + t.Setenv("STORAGE_TYPE", "redis") + t.Setenv("STORAGE_HOST", "localhost") + t.Setenv("STORAGE_PORT", "6379") + t.Setenv("LOGGING_LEVEL", "debug") + t.Setenv("RATE_LIMITING_ENABLE", "true") + t.Setenv("RATE_LIMITING_LIMIT", "100") defer os.Clearenv() cfg := NewConfig("nonexistent.yaml") // Load with environment variables - assert.Equal(t, "127.0.0.1", cfg.Host) - assert.Equal(t, 8080, cfg.Port) - assert.Equal(t, 15, cfg.KeyLength) - assert.Equal(t, 5000000, cfg.MaxLength) - assert.Equal(t, "redis", cfg.Storage.Type) - assert.Equal(t, "localhost", cfg.Storage.Host) - assert.Equal(t, 6379, cfg.Storage.Port) - assert.Equal(t, "debug", cfg.Logging.Level) + require.Equal(t, "127.0.0.1", cfg.Host) + require.Equal(t, 8080, cfg.Port) + require.Equal(t, 15, cfg.KeyLength) + require.Equal(t, 5000000, cfg.MaxLength) + require.Equal(t, "redis", cfg.Storage.Type) + require.Equal(t, "localhost", cfg.Storage.Host) + require.Equal(t, 6379, cfg.Storage.Port) + require.Equal(t, "debug", cfg.Logging.Level) + require.Equal(t, true, cfg.RateLimiting.Enable) + require.Equal(t, 100, cfg.RateLimiting.Limit) } func TestNewConfig_LoadFromYAML(t *testing.T) { @@ -60,21 +63,20 @@ logging: // Write to a temporary file tmpFile, err := os.CreateTemp("", "config_test_*.yaml") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(tmpFile.Name()) _, err = tmpFile.Write([]byte(yamlContent)) - assert.NoError(t, err) + require.NoError(t, err) tmpFile.Close() cfg := NewConfig(tmpFile.Name()) // Load config from YAML file - assert.Equal(t, "192.168.1.1", cfg.Host) - assert.Equal(t, 9090, cfg.Port) - assert.Equal(t, 20, cfg.KeyLength) - assert.Equal(t, "mongodb", cfg.Storage.Type) - assert.Equal(t, "mongo.example.com", cfg.Storage.Host) - assert.Equal(t, 27017, cfg.Storage.Port) - assert.Equal(t, "warn", cfg.Logging.Level) - assert.Equal(t, "json", cfg.Logging.Type) + require.Equal(t, "192.168.1.1", cfg.Host) + require.Equal(t, 9090, cfg.Port) + require.Equal(t, 20, cfg.KeyLength) + require.Equal(t, "mongodb", cfg.Storage.Type) + require.Equal(t, "mongo.example.com", cfg.Storage.Host) + require.Equal(t, 27017, cfg.Storage.Port) + require.Equal(t, "warn", cfg.Logging.Level) } diff --git a/internal/server/server.go b/internal/server/server.go index cf0980a..ac0624d 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -6,6 +6,7 @@ import ( "net/http" "os" "strconv" + "time" "github.com/armbian/ansi-hastebin/config" "github.com/armbian/ansi-hastebin/handler" @@ -13,8 +14,9 @@ import ( "github.com/armbian/ansi-hastebin/internal/storage" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" + "github.com/go-chi/httprate" "github.com/prometheus/client_golang/prometheus/promhttp" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog/log" ) type Server struct { @@ -46,6 +48,11 @@ func (s *Server) RegisterRoutes() { s.mux.Use(middleware.Logger) s.mux.Use(middleware.Recoverer) + // Rate limiter + if s.config.RateLimiting.Enable { + s.mux.Use(httprate.LimitByRealIP(s.config.RateLimiting.Limit, time.Duration(s.config.RateLimiting.Window)*time.Second)) + } + // Register promhttp middleware s.mux.Get("/metrics", promhttp.Handler().ServeHTTP) @@ -86,21 +93,21 @@ func (s *Server) RegisterRoutes() { } func (s *Server) Start() { - logrus.Infof("Starting server on %s", s.server.Addr) + log.Info().Str("host", s.config.Host).Int("port", s.config.Port).Msg("Starting server") if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logrus.WithError(err).Fatal("Failed to start server") + log.Fatal().Err(err).Msg("Failed to start server") } } func (s *Server) Shutdown(ctx context.Context) { - logrus.Info("Gracefully shutting down server") + log.Info().Msg("Gracefully shutting down server") if err := s.storage.Close(); err != nil { - logrus.WithError(err).Error("Failed to close storage") + log.Error().Err(err).Msg("Failed to close storage") } if err := s.server.Shutdown(ctx); err != nil { - logrus.WithError(err).Error("Failed to shutdown server") + log.Error().Err(err).Msg("Failed to shutdown server") } }