github.com/labstack/echo/v4
Advanced tools
+89
-20
| # Changelog | ||
| ## v4.15.0 - TBD | ||
| ## v4.15.0 - 2026-01-01 | ||
| **Security** | ||
| NB: **If your application relies on cross-origin or same-site (e.g. subdomain to main domain) requests, do not blindly push this version to production.** | ||
| The CSRF middleware now supports the [**Sec-Fetch-Site**](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Sec-Fetch-Site) header as a modern, defense-in-depth approach to [CSRF | ||
| protection](https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers), implementing the OWASP-recommended Fetch Metadata API alongside the traditional token-based mechanism. | ||
| **How it works:** | ||
| Modern browsers automatically send the `Sec-Fetch-Site` header with requests to potentially trustworthy (HTTPS or localhost) origins, indicating the relationship | ||
| between the request origin and the target. The middleware uses this to make security decisions: | ||
| - **`same-origin`** or **`none`**: Requests are allowed (exact origin match or direct user navigation) | ||
| - **`same-site`**: Falls back to token validation (e.g., subdomain to main domain) | ||
| - **`cross-site`** (or any unrecognized value): Blocked by default with a 403 error for unsafe methods (any method other than GET, HEAD, OPTIONS and TRACE, e.g. POST, PUT, DELETE, PATCH) | ||
| For browsers that don't send this header (older browsers), the middleware seamlessly falls back to | ||
| traditional token-based CSRF protection. | ||
| **New Configuration Options:** | ||
| - `TrustedOrigins []string`: Allowlist specific origins for cross-site requests (useful for OAuth callbacks, webhooks) | ||
| - `AllowSecFetchSiteFunc func(echo.Context) (bool, error)`: Custom logic for same-site/cross-site request validation | ||
| **Example:** | ||
| ```go | ||
| e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{ | ||
| // Allow OAuth callbacks from trusted provider | ||
| TrustedOrigins: []string{"https://oauth-provider.com"}, | ||
| // Custom validation for same-site requests | ||
| AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { | ||
| // Your custom authorization logic here | ||
| return validateCustomAuth(c), nil | ||
| // return true, err // blocks request with error | ||
| // return true, nil // allows CSRF request through | ||
| // return false, nil // falls back to legacy token logic | ||
| }, | ||
| })) | ||
| ``` | ||
| PR: https://github.com/labstack/echo/pull/2858 | ||
| **Type-Safe Generic Parameter Binding** | ||
| * Added generic functions for type-safe parameter extraction and context access by @aldas in https://github.com/labstack/echo/pull/2856 | ||
| Echo now provides generic functions for extracting path, query, and form parameters with automatic type conversion, | ||
| eliminating manual string parsing and type assertions. | ||
| **New Functions:** | ||
| - Path parameters: `PathParam[T]`, `PathParamOr[T]` | ||
| - Query parameters: `QueryParam[T]`, `QueryParamOr[T]`, `QueryParams[T]`, `QueryParamsOr[T]` | ||
| - Form values: `FormParam[T]`, `FormParamOr[T]`, `FormParams[T]`, `FormParamsOr[T]` | ||
| - Context store: `ContextGet[T]`, `ContextGetOr[T]` | ||
| **Supported Types:** | ||
| Primitives (`bool`, `string`, `int`/`uint` variants, `float32`/`float64`), `time.Duration`, `time.Time` | ||
| (with custom layouts and Unix timestamp support), and custom types implementing `BindUnmarshaler`, | ||
| `TextUnmarshaler`, or `JSONUnmarshaler`. | ||
| **Example:** | ||
| ```go | ||
| // Before: Manual parsing | ||
| idStr := c.Param("id") | ||
| id, err := strconv.Atoi(idStr) | ||
| // After: Type-safe with automatic parsing | ||
| id, err := echo.PathParam[int](c, "id") | ||
| // With default values | ||
| page, err := echo.QueryParamOr[int](c, "page", 1) | ||
| limit, err := echo.QueryParamOr[int](c, "limit", 20) | ||
| // Type-safe context access (no more panics from type assertions) | ||
| user, err := echo.ContextGet[*User](c, "user") | ||
| ``` | ||
| PR: https://github.com/labstack/echo/pull/2856 | ||
| **DEPRECATION NOTICE** Timeout Middleware Deprecated - Use ContextTimeout Instead | ||
@@ -40,21 +122,2 @@ | ||
| With configuration: | ||
| ```go | ||
| // Before (deprecated): | ||
| e.Use(middleware.TimeoutWithConfig(middleware.TimeoutConfig{ | ||
| Timeout: 30 * time.Second, | ||
| Skipper: func(c echo.Context) bool { | ||
| return c.Path() == "/health" | ||
| }, | ||
| })) | ||
| // After (recommended): | ||
| e.Use(middleware.ContextTimeoutWithConfig(middleware.ContextTimeoutConfig{ | ||
| Timeout: 30 * time.Second, | ||
| Skipper: func(c echo.Context) bool { | ||
| return c.Path() == "/health" | ||
| }, | ||
| })) | ||
| ``` | ||
| **Important Behavioral Differences:** | ||
@@ -116,3 +179,9 @@ | ||
| **Enhancements** | ||
| * Fixes by @aldas in https://github.com/labstack/echo/pull/2852 | ||
| * Generic functions by @aldas in https://github.com/labstack/echo/pull/2856 | ||
| * CSRF with Sec-Fetch-Site checks by @aldas in https://github.com/labstack/echo/pull/2858 | ||
| ## v4.14.0 - 2025-12-11 | ||
@@ -119,0 +188,0 @@ |
+12
-4
@@ -235,6 +235,9 @@ // SPDX-License-Identifier: MIT | ||
| HeaderServer = "Server" | ||
| HeaderOrigin = "Origin" | ||
| HeaderCacheControl = "Cache-Control" | ||
| HeaderConnection = "Connection" | ||
| // HeaderOrigin request header indicates the origin (scheme, hostname, and port) that caused the request. | ||
| // See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin | ||
| HeaderOrigin = "Origin" | ||
| HeaderCacheControl = "Cache-Control" | ||
| HeaderConnection = "Connection" | ||
| // Access control | ||
@@ -259,2 +262,7 @@ HeaderAccessControlRequestMethod = "Access-Control-Request-Method" | ||
| HeaderReferrerPolicy = "Referrer-Policy" | ||
| // HeaderSecFetchSite fetch metadata request header indicates the relationship between a request initiator's | ||
| // origin and the origin of the requested resource. | ||
| // See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Sec-Fetch-Site | ||
| HeaderSecFetchSite = "Sec-Fetch-Site" | ||
| ) | ||
@@ -264,3 +272,3 @@ | ||
| // Version of Echo | ||
| Version = "4.14.0" | ||
| Version = "4.15.0" | ||
| website = "https://echo.labstack.com" | ||
@@ -267,0 +275,0 @@ // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo |
+502
-34
@@ -7,2 +7,3 @@ // SPDX-License-Identifier: MIT | ||
| import ( | ||
| "cmp" | ||
| "net/http" | ||
@@ -20,11 +21,12 @@ "net/http/httptest" | ||
| var testCases = []struct { | ||
| name string | ||
| whenTokenLookup string | ||
| whenCookieName string | ||
| givenCSRFCookie string | ||
| givenMethod string | ||
| givenQueryTokens map[string][]string | ||
| givenFormTokens map[string][]string | ||
| givenHeaderTokens map[string][]string | ||
| expectError string | ||
| name string | ||
| whenTokenLookup string | ||
| whenCookieName string | ||
| givenCSRFCookie string | ||
| givenMethod string | ||
| givenQueryTokens map[string][]string | ||
| givenFormTokens map[string][]string | ||
| givenHeaderTokens map[string][]string | ||
| expectError string | ||
| expectToMiddlewareError string | ||
| }{ | ||
@@ -151,2 +153,10 @@ { | ||
| }, | ||
| { | ||
| name: "nok, invalid TokenLookup", | ||
| whenTokenLookup: "q", | ||
| givenCSRFCookie: "token", | ||
| givenMethod: http.MethodPut, | ||
| givenQueryTokens: map[string][]string{}, | ||
| expectToMiddlewareError: "extractor source for lookup could not be split into needed parts: q", | ||
| }, | ||
| } | ||
@@ -194,6 +204,13 @@ | ||
| csrf := CSRFWithConfig(CSRFConfig{ | ||
| config := CSRFConfig{ | ||
| TokenLookup: tc.whenTokenLookup, | ||
| CookieName: tc.whenCookieName, | ||
| }) | ||
| } | ||
| csrf, err := config.ToMiddleware() | ||
| if tc.expectToMiddlewareError != "" { | ||
| assert.EqualError(t, err, tc.expectToMiddlewareError) | ||
| return | ||
| } else if err != nil { | ||
| assert.NoError(t, err) | ||
| } | ||
@@ -204,3 +221,3 @@ h := csrf(func(c echo.Context) error { | ||
| err := h(c) | ||
| err = h(c) | ||
| if tc.expectError != "" { | ||
@@ -215,2 +232,121 @@ assert.EqualError(t, err, tc.expectError) | ||
| func TestCSRFWithConfig(t *testing.T) { | ||
| token := randomString(16) | ||
| var testCases = []struct { | ||
| name string | ||
| givenConfig *CSRFConfig | ||
| whenMethod string | ||
| whenHeaders map[string]string | ||
| expectEmptyBody bool | ||
| expectMWError string | ||
| expectCookieContains string | ||
| expectErr string | ||
| }{ | ||
| { | ||
| name: "ok, GET", | ||
| whenMethod: http.MethodGet, | ||
| expectCookieContains: "_csrf", | ||
| }, | ||
| { | ||
| name: "ok, POST valid token", | ||
| whenHeaders: map[string]string{ | ||
| echo.HeaderCookie: "_csrf=" + token, | ||
| echo.HeaderXCSRFToken: token, | ||
| }, | ||
| whenMethod: http.MethodPost, | ||
| expectCookieContains: "_csrf", | ||
| }, | ||
| { | ||
| name: "nok, POST without token", | ||
| whenMethod: http.MethodPost, | ||
| expectEmptyBody: true, | ||
| expectErr: `code=400, message=missing csrf token in request header`, | ||
| }, | ||
| { | ||
| name: "nok, POST empty token", | ||
| whenHeaders: map[string]string{echo.HeaderXCSRFToken: ""}, | ||
| whenMethod: http.MethodPost, | ||
| expectEmptyBody: true, | ||
| expectErr: `code=403, message=invalid csrf token`, | ||
| }, | ||
| { | ||
| name: "nok, invalid trusted origin in Config", | ||
| givenConfig: &CSRFConfig{ | ||
| TrustedOrigins: []string{"http://example.com", "invalid"}, | ||
| }, | ||
| expectMWError: `trusted origin is missing scheme or host: invalid`, | ||
| }, | ||
| { | ||
| name: "ok, TokenLength", | ||
| givenConfig: &CSRFConfig{ | ||
| TokenLength: 16, | ||
| }, | ||
| whenMethod: http.MethodGet, | ||
| expectCookieContains: "_csrf", | ||
| }, | ||
| { | ||
| name: "ok, unsafe method + SecFetchSite=same-origin passes", | ||
| whenHeaders: map[string]string{ | ||
| echo.HeaderSecFetchSite: "same-origin", | ||
| }, | ||
| whenMethod: http.MethodPost, | ||
| }, | ||
| { | ||
| name: "nok, unsafe method + SecFetchSite=same-cross blocked", | ||
| whenHeaders: map[string]string{ | ||
| echo.HeaderSecFetchSite: "same-cross", | ||
| }, | ||
| whenMethod: http.MethodPost, | ||
| expectEmptyBody: true, | ||
| expectErr: `code=403, message=cross-site request blocked by CSRF`, | ||
| }, | ||
| } | ||
| for _, tc := range testCases { | ||
| t.Run(tc.name, func(t *testing.T) { | ||
| e := echo.New() | ||
| req := httptest.NewRequest(cmp.Or(tc.whenMethod, http.MethodPost), "/", nil) | ||
| rec := httptest.NewRecorder() | ||
| c := e.NewContext(req, rec) | ||
| for key, value := range tc.whenHeaders { | ||
| req.Header.Set(key, value) | ||
| } | ||
| config := CSRFConfig{} | ||
| if tc.givenConfig != nil { | ||
| config = *tc.givenConfig | ||
| } | ||
| mw, err := config.ToMiddleware() | ||
| if tc.expectMWError != "" { | ||
| assert.EqualError(t, err, tc.expectMWError) | ||
| return | ||
| } | ||
| assert.NoError(t, err) | ||
| h := mw(func(c echo.Context) error { | ||
| return c.String(http.StatusOK, "test") | ||
| }) | ||
| err = h(c) | ||
| if tc.expectErr != "" { | ||
| assert.EqualError(t, err, tc.expectErr) | ||
| } else { | ||
| assert.NoError(t, err) | ||
| } | ||
| expect := "test" | ||
| if tc.expectEmptyBody { | ||
| expect = "" | ||
| } | ||
| assert.Equal(t, expect, rec.Body.String()) | ||
| if tc.expectCookieContains != "" { | ||
| assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), tc.expectCookieContains) | ||
| } | ||
| }) | ||
| } | ||
| } | ||
| func TestCSRF(t *testing.T) { | ||
@@ -230,22 +366,2 @@ e := echo.New() | ||
| // Without CSRF cookie | ||
| req = httptest.NewRequest(http.MethodPost, "/", nil) | ||
| rec = httptest.NewRecorder() | ||
| c = e.NewContext(req, rec) | ||
| assert.Error(t, h(c)) | ||
| // Empty/invalid CSRF token | ||
| req = httptest.NewRequest(http.MethodPost, "/", nil) | ||
| rec = httptest.NewRecorder() | ||
| c = e.NewContext(req, rec) | ||
| req.Header.Set(echo.HeaderXCSRFToken, "") | ||
| assert.Error(t, h(c)) | ||
| // Valid CSRF token | ||
| token := randomString(32) | ||
| req.Header.Set(echo.HeaderCookie, "_csrf="+token) | ||
| req.Header.Set(echo.HeaderXCSRFToken, token) | ||
| if assert.NoError(t, h(c)) { | ||
| assert.Equal(t, http.StatusOK, rec.Code) | ||
| } | ||
| } | ||
@@ -314,5 +430,6 @@ | ||
| csrf := CSRFWithConfig(CSRFConfig{ | ||
| csrf, err := CSRFConfig{ | ||
| CookieSameSite: http.SameSiteNoneMode, | ||
| }) | ||
| }.ToMiddleware() | ||
| assert.NoError(t, err) | ||
@@ -393,1 +510,352 @@ h := csrf(func(c echo.Context) error { | ||
| } | ||
| func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { | ||
| var testCases = []struct { | ||
| name string | ||
| givenConfig CSRFConfig | ||
| whenMethod string | ||
| whenSecFetchSite string | ||
| whenOrigin string | ||
| expectAllow bool | ||
| expectErr string | ||
| }{ | ||
| { | ||
| name: "ok, unsafe POST, no SecFetchSite is not blocked", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodPost, | ||
| whenSecFetchSite: "", | ||
| expectAllow: false, // should fall back to token CSRF | ||
| }, | ||
| { | ||
| name: "ok, safe GET + same-origin passes", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodGet, | ||
| whenSecFetchSite: "same-origin", | ||
| expectAllow: true, | ||
| }, | ||
| { | ||
| name: "ok, safe GET + none passes", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodGet, | ||
| whenSecFetchSite: "none", | ||
| expectAllow: true, | ||
| }, | ||
| { | ||
| name: "ok, safe GET + same-site passes", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodGet, | ||
| whenSecFetchSite: "same-site", | ||
| expectAllow: true, | ||
| }, | ||
| { | ||
| name: "ok, safe GET + cross-site passes", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodGet, | ||
| whenSecFetchSite: "cross-site", | ||
| expectAllow: true, | ||
| }, | ||
| { | ||
| name: "nok, unsafe POST + cross-site is blocked", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodPost, | ||
| whenSecFetchSite: "cross-site", | ||
| expectAllow: false, | ||
| expectErr: `code=403, message=cross-site request blocked by CSRF`, | ||
| }, | ||
| { | ||
| name: "nok, unsafe POST + same-site is blocked", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodPost, | ||
| whenSecFetchSite: "same-site", | ||
| expectAllow: false, | ||
| expectErr: ``, | ||
| }, | ||
| { | ||
| name: "ok, unsafe POST + same-origin passes", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodPost, | ||
| whenSecFetchSite: "same-origin", | ||
| expectAllow: true, | ||
| }, | ||
| { | ||
| name: "ok, unsafe POST + none passes", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodPost, | ||
| whenSecFetchSite: "none", | ||
| expectAllow: true, | ||
| }, | ||
| { | ||
| name: "ok, unsafe PUT + same-origin passes", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodPut, | ||
| whenSecFetchSite: "same-origin", | ||
| expectAllow: true, | ||
| }, | ||
| { | ||
| name: "ok, unsafe PUT + none passes", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodPut, | ||
| whenSecFetchSite: "none", | ||
| expectAllow: true, | ||
| }, | ||
| { | ||
| name: "ok, unsafe DELETE + same-origin passes", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodDelete, | ||
| whenSecFetchSite: "same-origin", | ||
| expectAllow: true, | ||
| }, | ||
| { | ||
| name: "ok, unsafe PATCH + same-origin passes", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodPatch, | ||
| whenSecFetchSite: "same-origin", | ||
| expectAllow: true, | ||
| }, | ||
| { | ||
| name: "nok, unsafe PUT + cross-site is blocked", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodPut, | ||
| whenSecFetchSite: "cross-site", | ||
| expectAllow: false, | ||
| expectErr: `code=403, message=cross-site request blocked by CSRF`, | ||
| }, | ||
| { | ||
| name: "nok, unsafe PUT + same-site is blocked", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodPut, | ||
| whenSecFetchSite: "same-site", | ||
| expectAllow: false, | ||
| expectErr: ``, | ||
| }, | ||
| { | ||
| name: "nok, unsafe DELETE + cross-site is blocked", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodDelete, | ||
| whenSecFetchSite: "cross-site", | ||
| expectAllow: false, | ||
| expectErr: `code=403, message=cross-site request blocked by CSRF`, | ||
| }, | ||
| { | ||
| name: "nok, unsafe DELETE + same-site is blocked", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodDelete, | ||
| whenSecFetchSite: "same-site", | ||
| expectAllow: false, | ||
| expectErr: ``, | ||
| }, | ||
| { | ||
| name: "nok, unsafe PATCH + cross-site is blocked", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodPatch, | ||
| whenSecFetchSite: "cross-site", | ||
| expectAllow: false, | ||
| expectErr: `code=403, message=cross-site request blocked by CSRF`, | ||
| }, | ||
| { | ||
| name: "ok, safe HEAD + same-origin passes", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodHead, | ||
| whenSecFetchSite: "same-origin", | ||
| expectAllow: true, | ||
| }, | ||
| { | ||
| name: "ok, safe HEAD + cross-site passes", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodHead, | ||
| whenSecFetchSite: "cross-site", | ||
| expectAllow: true, | ||
| }, | ||
| { | ||
| name: "ok, safe OPTIONS + cross-site passes", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodOptions, | ||
| whenSecFetchSite: "cross-site", | ||
| expectAllow: true, | ||
| }, | ||
| { | ||
| name: "ok, safe TRACE + cross-site passes", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodTrace, | ||
| whenSecFetchSite: "cross-site", | ||
| expectAllow: true, | ||
| }, | ||
| { | ||
| name: "ok, unsafe POST + cross-site + matching trusted origin passes", | ||
| givenConfig: CSRFConfig{ | ||
| TrustedOrigins: []string{"https://trusted.example.com"}, | ||
| }, | ||
| whenMethod: http.MethodPost, | ||
| whenSecFetchSite: "cross-site", | ||
| whenOrigin: "https://trusted.example.com", | ||
| expectAllow: true, | ||
| }, | ||
| { | ||
| name: "ok, unsafe POST + same-site + matching trusted origin passes", | ||
| givenConfig: CSRFConfig{ | ||
| TrustedOrigins: []string{"https://trusted.example.com"}, | ||
| }, | ||
| whenMethod: http.MethodPost, | ||
| whenSecFetchSite: "same-site", | ||
| whenOrigin: "https://trusted.example.com", | ||
| expectAllow: true, | ||
| }, | ||
| { | ||
| name: "nok, unsafe POST + cross-site + non-matching origin is blocked", | ||
| givenConfig: CSRFConfig{ | ||
| TrustedOrigins: []string{"https://trusted.example.com"}, | ||
| }, | ||
| whenMethod: http.MethodPost, | ||
| whenSecFetchSite: "cross-site", | ||
| whenOrigin: "https://evil.example.com", | ||
| expectAllow: false, | ||
| expectErr: `code=403, message=cross-site request blocked by CSRF`, | ||
| }, | ||
| { | ||
| name: "ok, unsafe POST + cross-site + case-insensitive trusted origin match passes", | ||
| givenConfig: CSRFConfig{ | ||
| TrustedOrigins: []string{"https://trusted.example.com"}, | ||
| }, | ||
| whenMethod: http.MethodPost, | ||
| whenSecFetchSite: "cross-site", | ||
| whenOrigin: "https://TRUSTED.example.com", | ||
| expectAllow: true, | ||
| }, | ||
| { | ||
| name: "ok, unsafe POST + same-origin + trusted origins configured but not matched passes", | ||
| givenConfig: CSRFConfig{ | ||
| TrustedOrigins: []string{"https://trusted.example.com"}, | ||
| }, | ||
| whenMethod: http.MethodPost, | ||
| whenSecFetchSite: "same-origin", | ||
| whenOrigin: "https://different.example.com", | ||
| expectAllow: true, | ||
| }, | ||
| { | ||
| name: "nok, unsafe POST + cross-site + empty origin + trusted origins configured is blocked", | ||
| givenConfig: CSRFConfig{ | ||
| TrustedOrigins: []string{"https://trusted.example.com"}, | ||
| }, | ||
| whenMethod: http.MethodPost, | ||
| whenSecFetchSite: "cross-site", | ||
| whenOrigin: "", | ||
| expectAllow: false, | ||
| expectErr: `code=403, message=cross-site request blocked by CSRF`, | ||
| }, | ||
| { | ||
| name: "ok, unsafe POST + cross-site + multiple trusted origins, second one matches", | ||
| givenConfig: CSRFConfig{ | ||
| TrustedOrigins: []string{"https://first.example.com", "https://second.example.com"}, | ||
| }, | ||
| whenMethod: http.MethodPost, | ||
| whenSecFetchSite: "cross-site", | ||
| whenOrigin: "https://second.example.com", | ||
| expectAllow: true, | ||
| }, | ||
| { | ||
| name: "ok, unsafe POST + same-site + custom func allows", | ||
| givenConfig: CSRFConfig{ | ||
| AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { | ||
| return true, nil | ||
| }, | ||
| }, | ||
| whenMethod: http.MethodPost, | ||
| whenSecFetchSite: "same-site", | ||
| expectAllow: true, | ||
| }, | ||
| { | ||
| name: "ok, unsafe POST + cross-site + custom func allows", | ||
| givenConfig: CSRFConfig{ | ||
| AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { | ||
| return true, nil | ||
| }, | ||
| }, | ||
| whenMethod: http.MethodPost, | ||
| whenSecFetchSite: "cross-site", | ||
| expectAllow: true, | ||
| }, | ||
| { | ||
| name: "nok, unsafe POST + same-site + custom func returns custom error", | ||
| givenConfig: CSRFConfig{ | ||
| AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { | ||
| return false, echo.NewHTTPError(http.StatusTeapot, "custom error from func") | ||
| }, | ||
| }, | ||
| whenMethod: http.MethodPost, | ||
| whenSecFetchSite: "same-site", | ||
| expectAllow: false, | ||
| expectErr: `code=418, message=custom error from func`, | ||
| }, | ||
| { | ||
| name: "nok, unsafe POST + cross-site + custom func returns false with nil error", | ||
| givenConfig: CSRFConfig{ | ||
| AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { | ||
| return false, nil | ||
| }, | ||
| }, | ||
| whenMethod: http.MethodPost, | ||
| whenSecFetchSite: "cross-site", | ||
| expectAllow: false, | ||
| expectErr: "", // custom func returns nil error, so no error expected | ||
| }, | ||
| { | ||
| name: "nok, unsafe POST + invalid Sec-Fetch-Site value treated as cross-site", | ||
| givenConfig: CSRFConfig{}, | ||
| whenMethod: http.MethodPost, | ||
| whenSecFetchSite: "invalid-value", | ||
| expectAllow: false, | ||
| expectErr: `code=403, message=cross-site request blocked by CSRF`, | ||
| }, | ||
| { | ||
| name: "ok, unsafe POST + cross-site + trusted origin takes precedence over custom func", | ||
| givenConfig: CSRFConfig{ | ||
| TrustedOrigins: []string{"https://trusted.example.com"}, | ||
| AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { | ||
| return false, echo.NewHTTPError(http.StatusTeapot, "should not be called") | ||
| }, | ||
| }, | ||
| whenMethod: http.MethodPost, | ||
| whenSecFetchSite: "cross-site", | ||
| whenOrigin: "https://trusted.example.com", | ||
| expectAllow: true, | ||
| }, | ||
| { | ||
| name: "nok, unsafe POST + cross-site + trusted origin not matched, custom func blocks", | ||
| givenConfig: CSRFConfig{ | ||
| TrustedOrigins: []string{"https://trusted.example.com"}, | ||
| AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { | ||
| return false, echo.NewHTTPError(http.StatusTeapot, "custom block") | ||
| }, | ||
| }, | ||
| whenMethod: http.MethodPost, | ||
| whenSecFetchSite: "cross-site", | ||
| whenOrigin: "https://evil.example.com", | ||
| expectAllow: false, | ||
| expectErr: `code=418, message=custom block`, | ||
| }, | ||
| } | ||
| for _, tc := range testCases { | ||
| t.Run(tc.name, func(t *testing.T) { | ||
| req := httptest.NewRequest(tc.whenMethod, "/", nil) | ||
| if tc.whenSecFetchSite != "" { | ||
| req.Header.Set(echo.HeaderSecFetchSite, tc.whenSecFetchSite) | ||
| } | ||
| if tc.whenOrigin != "" { | ||
| req.Header.Set(echo.HeaderOrigin, tc.whenOrigin) | ||
| } | ||
| res := httptest.NewRecorder() | ||
| e := echo.New() | ||
| c := e.NewContext(req, res) | ||
| allow, err := tc.givenConfig.checkSecFetchSiteRequest(c) | ||
| assert.Equal(t, tc.expectAllow, allow) | ||
| if tc.expectErr != "" { | ||
| assert.EqualError(t, err, tc.expectErr) | ||
| } else { | ||
| assert.NoError(t, err) | ||
| } | ||
| }) | ||
| } | ||
| } |
+88
-3
@@ -9,2 +9,4 @@ // SPDX-License-Identifier: MIT | ||
| "net/http" | ||
| "slices" | ||
| "strings" | ||
| "time" | ||
@@ -20,2 +22,18 @@ | ||
| // TrustedOrigin permits any request with `Sec-Fetch-Site` header whose `Origin` header | ||
| // exactly matches the specified value. | ||
| // Values should be formated as Origin header "scheme://host[:port]". | ||
| // | ||
| // See [Origin]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin | ||
| // See [Sec-Fetch-Site]: https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers | ||
| TrustedOrigins []string | ||
| // AllowSecFetchSameSite allows custom behaviour for `Sec-Fetch-Site` requests that are about to | ||
| // fail with CRSF error, to be allowed or replaced with custom error. | ||
| // This function applies to `Sec-Fetch-Site` values: | ||
| // - `same-site` same registrable domain (subdomain and/or different port) | ||
| // - `cross-site` request originates from different site | ||
| // See [Sec-Fetch-Site]: https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers | ||
| AllowSecFetchSiteFunc func(c echo.Context) (bool, error) | ||
| // TokenLength is the length of the generated token. | ||
@@ -99,3 +117,7 @@ TokenLength uint8 `yaml:"token_length"` | ||
| func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { | ||
| // Defaults | ||
| return toMiddlewareOrPanic(config) | ||
| } | ||
| // ToMiddleware converts CSRFConfig to middleware or returns an error for invalid configuration | ||
| func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { | ||
| if config.Skipper == nil { | ||
@@ -123,6 +145,12 @@ config.Skipper = DefaultCSRFConfig.Skipper | ||
| } | ||
| if len(config.TrustedOrigins) > 0 { | ||
| if vErr := validateOrigins(config.TrustedOrigins, "trusted origin"); vErr != nil { | ||
| return nil, vErr | ||
| } | ||
| config.TrustedOrigins = append([]string(nil), config.TrustedOrigins...) | ||
| } | ||
| extractors, cErr := CreateExtractors(config.TokenLookup) | ||
| if cErr != nil { | ||
| panic(cErr) | ||
| return nil, cErr | ||
| } | ||
@@ -136,2 +164,13 @@ | ||
| // use the `Sec-Fetch-Site` header as part of a modern approach to CSRF protection | ||
| allow, err := config.checkSecFetchSiteRequest(c) | ||
| if err != nil { | ||
| return err | ||
| } | ||
| if allow { | ||
| return next(c) | ||
| } | ||
| // Fallback to legacy token based CSRF protection | ||
| token := "" | ||
@@ -218,3 +257,3 @@ if k, err := c.Cookie(config.CookieName); err != nil { | ||
| } | ||
| } | ||
| }, nil | ||
| } | ||
@@ -225,1 +264,47 @@ | ||
| } | ||
| var safeMethods = []string{http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace} | ||
| func (config CSRFConfig) checkSecFetchSiteRequest(c echo.Context) (bool, error) { | ||
| // https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers | ||
| // Sec-Fetch-Site values are: | ||
| // - `same-origin` exact origin match - allow always | ||
| // - `same-site` same registrable domain (subdomain and/or different port) - block, unless explicitly trusted | ||
| // - `cross-site` request originates from different site - block, unless explicitly trusted | ||
| // - `none` direct navigation (URL bar, bookmark) - allow always | ||
| secFetchSite := c.Request().Header.Get(echo.HeaderSecFetchSite) | ||
| if secFetchSite == "" { | ||
| return false, nil | ||
| } | ||
| if len(config.TrustedOrigins) > 0 { | ||
| // trusted sites ala OAuth callbacks etc. should be let through | ||
| origin := c.Request().Header.Get(echo.HeaderOrigin) | ||
| if origin != "" { | ||
| for _, trustedOrigin := range config.TrustedOrigins { | ||
| if strings.EqualFold(origin, trustedOrigin) { | ||
| return true, nil | ||
| } | ||
| } | ||
| } | ||
| } | ||
| isSafe := slices.Contains(safeMethods, c.Request().Method) | ||
| if !isSafe { // for state-changing request check SecFetchSite value | ||
| isSafe = secFetchSite == "same-origin" || secFetchSite == "none" | ||
| } | ||
| if isSafe { | ||
| return true, nil | ||
| } | ||
| // we are here when request is state-changing and `cross-site` or `same-site` | ||
| // Note: if you want to block `same-site` use config.TrustedOrigins or `config.AllowSecFetchSiteFunc` | ||
| if config.AllowSecFetchSiteFunc != nil { | ||
| return config.AllowSecFetchSiteFunc(c) | ||
| } | ||
| if secFetchSite == "same-site" { | ||
| return false, nil // fall back to legacy token | ||
| } | ||
| return false, echo.NewHTTPError(http.StatusForbidden, "cross-site request blocked by CSRF") | ||
| } |
@@ -91,1 +91,11 @@ // SPDX-License-Identifier: MIT | ||
| } | ||
| func toMiddlewareOrPanic(config interface { | ||
| ToMiddleware() (echo.MiddlewareFunc, error) | ||
| }) echo.MiddlewareFunc { | ||
| mw, err := config.ToMiddleware() | ||
| if err != nil { | ||
| panic(err) | ||
| } | ||
| return mw | ||
| } |
+206
-0
@@ -152,1 +152,207 @@ // SPDX-License-Identifier: MIT | ||
| } | ||
| func TestValidateOrigins(t *testing.T) { | ||
| var testCases = []struct { | ||
| name string | ||
| givenOrigins []string | ||
| givenWhat string | ||
| expectErr string | ||
| }{ | ||
| // Valid cases | ||
| { | ||
| name: "ok, empty origins", | ||
| givenOrigins: []string{}, | ||
| }, | ||
| { | ||
| name: "ok, basic http", | ||
| givenOrigins: []string{"http://example.com"}, | ||
| }, | ||
| { | ||
| name: "ok, basic https", | ||
| givenOrigins: []string{"https://example.com"}, | ||
| }, | ||
| { | ||
| name: "ok, with port", | ||
| givenOrigins: []string{"http://localhost:8080"}, | ||
| }, | ||
| { | ||
| name: "ok, with subdomain", | ||
| givenOrigins: []string{"https://api.example.com"}, | ||
| }, | ||
| { | ||
| name: "ok, subdomain with port", | ||
| givenOrigins: []string{"https://api.example.com:8080"}, | ||
| }, | ||
| { | ||
| name: "ok, localhost", | ||
| givenOrigins: []string{"http://localhost"}, | ||
| }, | ||
| { | ||
| name: "ok, IPv4 address", | ||
| givenOrigins: []string{"http://192.168.1.1"}, | ||
| }, | ||
| { | ||
| name: "ok, IPv4 with port", | ||
| givenOrigins: []string{"http://192.168.1.1:8080"}, | ||
| }, | ||
| { | ||
| name: "ok, IPv6 loopback", | ||
| givenOrigins: []string{"http://[::1]"}, | ||
| }, | ||
| { | ||
| name: "ok, IPv6 with port", | ||
| givenOrigins: []string{"http://[::1]:8080"}, | ||
| }, | ||
| { | ||
| name: "ok, IPv6 full address", | ||
| givenOrigins: []string{"http://[2001:db8::1]"}, | ||
| }, | ||
| { | ||
| name: "ok, multiple valid origins", | ||
| givenOrigins: []string{"http://example.com", "https://api.example.com:8080"}, | ||
| }, | ||
| { | ||
| name: "ok, different schemes", | ||
| givenOrigins: []string{"http://example.com", "https://example.com", "ws://example.com"}, | ||
| }, | ||
| // Invalid - missing scheme | ||
| { | ||
| name: "nok, plain domain", | ||
| givenOrigins: []string{"example.com"}, | ||
| expectErr: "trusted origin is missing scheme or host: example.com", | ||
| }, | ||
| { | ||
| name: "nok, with slashes but no scheme", | ||
| givenOrigins: []string{"//example.com"}, | ||
| expectErr: "trusted origin is missing scheme or host: //example.com", | ||
| }, | ||
| { | ||
| name: "nok, www without scheme", | ||
| givenOrigins: []string{"www.example.com"}, | ||
| expectErr: "trusted origin is missing scheme or host: www.example.com", | ||
| }, | ||
| { | ||
| name: "nok, localhost without scheme", | ||
| givenOrigins: []string{"localhost:8080"}, | ||
| expectErr: "trusted origin is missing scheme or host: localhost:8080", | ||
| }, | ||
| // Invalid - missing host | ||
| { | ||
| name: "nok, scheme only http", | ||
| givenOrigins: []string{"http://"}, | ||
| expectErr: "trusted origin is missing scheme or host: http://", | ||
| }, | ||
| { | ||
| name: "nok, scheme only https", | ||
| givenOrigins: []string{"https://"}, | ||
| expectErr: "trusted origin is missing scheme or host: https://", | ||
| }, | ||
| // Invalid - has path | ||
| { | ||
| name: "nok, has simple path", | ||
| givenOrigins: []string{"http://example.com/path"}, | ||
| expectErr: "trusted origin can not have path, query, and fragments: http://example.com/path", | ||
| }, | ||
| { | ||
| name: "nok, has nested path", | ||
| givenOrigins: []string{"https://example.com/api/v1"}, | ||
| expectErr: "trusted origin can not have path, query, and fragments: https://example.com/api/v1", | ||
| }, | ||
| { | ||
| name: "nok, has root path", | ||
| givenOrigins: []string{"http://example.com/"}, | ||
| expectErr: "trusted origin can not have path, query, and fragments: http://example.com/", | ||
| }, | ||
| // Invalid - has query | ||
| { | ||
| name: "nok, has single query param", | ||
| givenOrigins: []string{"http://example.com?foo=bar"}, | ||
| expectErr: "trusted origin can not have path, query, and fragments: http://example.com?foo=bar", | ||
| }, | ||
| { | ||
| name: "nok, has multiple query params", | ||
| givenOrigins: []string{"https://example.com?foo=bar&baz=qux"}, | ||
| expectErr: "trusted origin can not have path, query, and fragments: https://example.com?foo=bar&baz=qux", | ||
| }, | ||
| // Invalid - has fragment | ||
| { | ||
| name: "nok, has simple fragment", | ||
| givenOrigins: []string{"http://example.com#section"}, | ||
| expectErr: "trusted origin can not have path, query, and fragments: http://example.com#section", | ||
| }, | ||
| // Invalid - combinations | ||
| { | ||
| name: "nok, has path and query", | ||
| givenOrigins: []string{"http://example.com/path?foo=bar"}, | ||
| expectErr: "trusted origin can not have path, query, and fragments: http://example.com/path?foo=bar", | ||
| }, | ||
| { | ||
| name: "nok, has path and fragment", | ||
| givenOrigins: []string{"http://example.com/path#section"}, | ||
| expectErr: "trusted origin can not have path, query, and fragments: http://example.com/path#section", | ||
| }, | ||
| { | ||
| name: "nok, has query and fragment", | ||
| givenOrigins: []string{"http://example.com?foo=bar#section"}, | ||
| expectErr: "trusted origin can not have path, query, and fragments: http://example.com?foo=bar#section", | ||
| }, | ||
| { | ||
| name: "nok, has path, query, and fragment", | ||
| givenOrigins: []string{"http://example.com/path?foo=bar#section"}, | ||
| expectErr: "trusted origin can not have path, query, and fragments: http://example.com/path?foo=bar#section", | ||
| }, | ||
| // Edge cases | ||
| { | ||
| name: "nok, empty string", | ||
| givenOrigins: []string{""}, | ||
| expectErr: "trusted origin is missing scheme or host: ", | ||
| }, | ||
| { | ||
| name: "nok, whitespace only", | ||
| givenOrigins: []string{" "}, | ||
| expectErr: "trusted origin is missing scheme or host: ", | ||
| }, | ||
| { | ||
| name: "nok, multiple origins - first invalid", | ||
| givenOrigins: []string{"example.com", "http://valid.com"}, | ||
| expectErr: "trusted origin is missing scheme or host: example.com", | ||
| }, | ||
| { | ||
| name: "nok, multiple origins - middle invalid", | ||
| givenOrigins: []string{"http://valid1.com", "invalid.com", "http://valid2.com"}, | ||
| expectErr: "trusted origin is missing scheme or host: invalid.com", | ||
| }, | ||
| { | ||
| name: "nok, multiple origins - last invalid", | ||
| givenOrigins: []string{"http://valid.com", "invalid.com"}, | ||
| expectErr: "trusted origin is missing scheme or host: invalid.com", | ||
| }, | ||
| // Different "what" parameter | ||
| { | ||
| name: "nok, custom what parameter - missing scheme", | ||
| givenOrigins: []string{"example.com"}, | ||
| givenWhat: "allowed origin", | ||
| expectErr: "allowed origin is missing scheme or host: example.com", | ||
| }, | ||
| { | ||
| name: "nok, custom what parameter - has path", | ||
| givenOrigins: []string{"http://example.com/path"}, | ||
| givenWhat: "cors origin", | ||
| expectErr: "cors origin can not have path, query, and fragments: http://example.com/path", | ||
| }, | ||
| } | ||
| for _, tc := range testCases { | ||
| t.Run(tc.name, func(t *testing.T) { | ||
| what := tc.givenWhat | ||
| if what == "" { | ||
| what = "trusted origin" | ||
| } | ||
| err := validateOrigins(tc.givenOrigins, what) | ||
| if tc.expectErr != "" { | ||
| assert.EqualError(t, err, tc.expectErr) | ||
| } else { | ||
| assert.NoError(t, err) | ||
| } | ||
| }) | ||
| } | ||
| } |
+25
-0
@@ -9,3 +9,5 @@ // SPDX-License-Identifier: MIT | ||
| "crypto/rand" | ||
| "fmt" | ||
| "io" | ||
| "net/url" | ||
| "strings" | ||
@@ -105,1 +107,24 @@ "sync" | ||
| } | ||
// validateOrigins checks every entry of origins with validateOrigin and
// returns the first error encountered. It returns nil when all entries are
// valid, including when origins is empty or nil. what names the configuration
// item being checked (e.g. "trusted origin") and is used as the subject of
// error messages.
func validateOrigins(origins []string, what string) error {
	for _, o := range origins {
		if err := validateOrigin(o, what); err != nil {
			return err
		}
	}
	return nil
}

// validateOrigin checks that origin is a bare origin of the form
// scheme://host[:port], which is how browsers serialize the Origin header.
//
// It returns an error when origin:
//   - cannot be parsed;
//   - lacks a scheme or host;
//   - carries user info;
//   - has a path, query or fragment.
//
// An empty trailing "?" or "#" also counts as a query or fragment. Such a
// value can never equal a browser-sent Origin header, so accepting it would
// silently create an entry that never matches.
func validateOrigin(origin string, what string) error {
	u, err := url.Parse(origin)
	if err != nil {
		return fmt.Errorf("can not parse %s: %w", what, err)
	}
	if u.Scheme == "" || u.Host == "" {
		return fmt.Errorf("%s is missing scheme or host: %s", what, origin)
	}
	if u.User != nil {
		// Redacted masks any password so credentials do not end up in logs.
		return fmt.Errorf("%s can not have user info: %s", what, u.Redacted())
	}
	// url.Parse does not record an empty trailing "#" at all and records an
	// empty trailing "?" only via ForceQuery, so check for those explicitly.
	// url.Parse splits on the first '#' before anything else, so any '#' in
	// origin starts a fragment.
	hasQuery := u.RawQuery != "" || u.ForceQuery
	hasFragment := strings.Contains(origin, "#")
	if u.Path != "" || hasQuery || hasFragment {
		return fmt.Errorf("%s can not have path, query, and fragments: %s", what, origin)
	}
	return nil
}