From b1eb82e1b988d140af6379926aed0d43aaaf31ef Mon Sep 17 00:00:00 2001 From: Sangbum Kim Date: Fri, 25 Mar 2022 06:06:43 +0900 Subject: [PATCH] initial commit --- .gitignore | 366 ++++++++++++++++++++++++++ LICENSE | 22 ++ README.md | 24 ++ client/client.go | 76 ++++++ client/index.go | 2 + client/round_tripper.go | 53 ++++ cookie.go | 53 ++++ cookie_fasthttp.go | 165 ++++++++++++ encodings.go | 88 +++++++ go.mod | 29 ++ index.go | 2 + json.go | 25 ++ middleware/accesslog_fasthttp.go | 325 +++++++++++++++++++++++ middleware/cache_control.go | 34 +++ middleware/csrf_fasthttp.go | 185 +++++++++++++ middleware/index.go | 2 + middleware/internal/secret_box_dto.go | 56 ++++ middleware/internal/session_dto.go | 134 ++++++++++ middleware/internal/session_pool.go | 39 +++ middleware/secret_box_fasthttp.go | 46 ++++ middleware/session_fasthttp.go | 192 ++++++++++++++ routing/context.go | 71 +++++ routing/helper.go | 25 ++ routing/iface.go | 14 + routing/index.go | 2 + routing/registry.go | 99 +++++++ routing/reverse_route.go | 247 +++++++++++++++++ serve/fasthttp.go | 210 +++++++++++++++ serve/iface.go | 27 ++ serve/index.go | 2 + serve/link.go | 40 +++ serve/static_fasthttp.go | 153 +++++++++++ serve/stub.s | 4 + status_resp.go | 263 ++++++++++++++++++ 34 files changed, 3075 insertions(+) create mode 100755 .gitignore create mode 100755 LICENSE create mode 100755 README.md create mode 100644 client/client.go create mode 100644 client/index.go create mode 100644 client/round_tripper.go create mode 100644 cookie.go create mode 100644 cookie_fasthttp.go create mode 100644 encodings.go create mode 100644 go.mod create mode 100644 index.go create mode 100644 json.go create mode 100644 middleware/accesslog_fasthttp.go create mode 100644 middleware/cache_control.go create mode 100644 middleware/csrf_fasthttp.go create mode 100644 middleware/index.go create mode 100644 middleware/internal/secret_box_dto.go create mode 100644 middleware/internal/session_dto.go create mode 100644 middleware/internal/session_pool.go create mode 100644 middleware/secret_box_fasthttp.go create mode 100644 middleware/session_fasthttp.go create mode 100644 routing/context.go create mode 100644 routing/helper.go create mode 100644 routing/iface.go create mode 100644 routing/index.go create mode 100644 routing/registry.go create mode 100644 routing/reverse_route.go create mode 100644 serve/fasthttp.go create mode 100644 serve/iface.go create mode 100644 serve/index.go create mode 100644 serve/link.go create mode 100644 serve/static_fasthttp.go create mode 100644 serve/stub.s create mode 100644 status_resp.go diff --git a/.gitignore b/.gitignore new file mode 100755 index 0000000..bff3575 --- /dev/null +++ b/.gitignore @@ -0,0 +1,366 @@ +# Created by https://www.gitignore.io/api/intellij,go,linux,osx,windows,node,python,executable,jetbrains+all,visualstudiocode,compressedarchive,git +# Edit at https://www.gitignore.io/?templates=intellij,go,linux,osx,windows,node,python,executable,jetbrains+all,visualstudiocode,compressedarchive,git + +### CompressedArchive ### + +### Mostly from https://en.wikipedia.org/wiki/List_of_archive_formats + +## Archiving and compression +# Open source file format. Used by 7-Zip. +*.7z +# Mac OS X, restoration on different platforms is possible although not immediate Yes Based on 7z. Preserves Spotlight metadata, resource forks, owner/group information, dates and other data which would be otherwise lost with compression. +*.s7z +# Old archive versions only Proprietary format +*.ace +# A format that compresses and doubly encrypt the data (AES256 and CAS256) avoiding brute force attacks, also hide files in an AFA file. It has two ways to safeguard data integrity and subsequent repair of the file if has an error (repair with AstroA2P (online) or Astrotite (offline)). +*.afa +# A mainly Korean format designed for very large archives. +*.alz +# Android application package (variant of JAR file format). +*.apk +# ?? +*.arc +# Originally DOS, now multiple +*.arj +# Open archive format, used by B1 Free Archiver (http://dev.b1.org/standard/archive-format.html) +*.b1 +# Binary Archive with external header +*.ba +# Proprietary format from the ZipTV Compression Components +*.bh +# The Microsoft Windows native archive format, which is also used by many commercial installers such as InstallShield and WISE. +*.cab +# Originally DOS, now DOS and Windows Created by Yaakov Gringeler; released last in 2003 (Compressia 1.0.0.1 beta), now apparently defunct. Free trial of 30 days lets user create and extract archives; after that it is possible to extract, but not to create. +*.car +# Open source file format. +*.cfs +# Compact Pro archive, a common archiver used on Mac platforms until about Mac OS 7.5.x. Competed with StuffIt; now obsolete. +*.cpt +# Windows, Unix-like, Mac OS X Open source file format. Files are compressed individually with either gzip, bzip2 or lzo. +*.dar +# DiskDoubler Mac OS obsolete +*.dd +# ?? +*.dgc +# Apple Disk Image upports "Internet-enabled" disk images, which, once downloaded, are automatically decompressed, mounted, have the contents extracted, and thrown away. Currently, Safari is the only browser that supports this form of extraction; however, the images can be manually extracted as well. This format can also be password-protected or encrypted with 128-bit or 256-bit AES encryption. +*.dmg +# Enterprise Java Archive archive +*.ear +# ETSoft compressed archive +*.egg +# The predecessor of DGCA. +*.gca +# Originally DOS Yes, but may be covered by patents DOS era format; uses arithmetic/Markov coding +*.ha +# MS Windows HKI +*.hki +# Produced by ICEOWS program. Excels at text file compression. +*.ice +# Java archive, compatible with ZIP files +*.jar +# Open sourced archiver with compression using the PAQ family of algorithms and optional encryption. +*.kgb +# Originally DOS, now multiple Multiple Yes The standard format on Amiga. +*.lzh +*.lha +# Archiver originally used on The Amiga. Now copied by Microsoft to use in their .cab and .chm files. +*.lzx +# file format from NoGate Consultings, a rival from ARC-Compressor. +*.pak +# A disk image archive format that supports several compression methods as well as splitting the archive into smaller pieces. +*.partimg +# An experimental open source packager (http://mattmahoney.net/dc) +*.paq* +# Open source archiver supporting authenticated encryption, volume spanning, customizable object level and volume level integrity checks (form CRCs to SHA-512 and Whirlpool hashes), fast deflate based compression +*.pea +# The format from the PIM - a freeware compression tool by Ilia Muraviev. It uses an LZP-based compression algorithm with set of filters for executable, image and audio files. +*.pim +# PackIt Mac OS obsolete +*.pit +# Used for data in games written using the Quadruple D library for Delphi. Uses byte pair compression. +*.qda +# A proprietary archive format, second in popularity to .zip files. +*.rar +# The format from a commercial archiving package. Odd among commercial packages in that they focus on incorporating experimental algorithms with the highest possible compression (at the expense of speed and memory), such as PAQ, PPMD and PPMZ (PPMD with unlimited-length strings), as well as a proprietary algorithms. +*.rk +# Self Dissolving ARChive Commodore 64, Commodore 128 Commodore 64, Commodore 128 Yes SDAs refer to Self Dissolving ARC files, and are based on the Commodore 64 and Commodore 128 versions of ARC, originally written by Chris Smeets. While the files share the same extension, they are not compatible between platforms. That is, an SDA created on a Commodore 64 but run on a Commodore 128 in Commodore 128 mode will crash the machine, and vice versa. The intended successor to SDA is SFX. +*.sda +# A pre-Mac OS X Self-Extracting Archive format. StuffIt, Compact Pro, Disk Doubler and others could create .sea files, though the StuffIt versions were the most common. +*.sea +# Scifer Archive with internal header +*.sen +# Commodore 64, Commodore 128 SFX is a Self Extracting Archive which uses the LHArc compression algorithm. It was originally developed by Chris Smeets on the Commodore platform, and runs primarily using the CS-DOS extension for the Commodore 128. Unlike its predecessor SDA, SFX files will run on both the Commodore 64 and Commodore 128 regardless of which machine they were created on. +*.sfx +# An archive format designed for the Apple II series of computers. The canonical implementation is ShrinkIt, which can operate on disk images as well as files. Preferred compression algorithm is a combination of RLE and 12-bit LZW. Archives can be manipulated with the command-line NuLib tool, or the Windows-based CiderPress. +*.shk +# A compression format common on Apple Macintosh computers. The free StuffIt Expander is available for Windows and OS X. +*.sit +# The replacement for the .sit format that supports more compression methods, UNIX file permissions, long file names, very large files, more encryption options, data specific compressors (JPEG, Zip, PDF, 24-bit image, MP3). The free StuffIt Expander is available for Windows and OS X. +*.sitx +# A royalty-free compressing format +*.sqx +# The "tarball" format combines tar archives with a file-based compression scheme (usually gzip). Commonly used for source and binary distribution on Unix-like platforms, widely available elsewhere. +*.tar.gz +*.tgz +*.tar.Z +*.tar.bz2 +*.tbz2 +*.tar.lzma +*.tlz +# UltraCompressor 2.3 was developed to act as an alternative to the then popular PKZIP application. The main feature of the application is its ability to create large archives. This means that compressed archives with the UC2 file extension can hold almost 1 million files. +*.uc +*.uc0 +*.uc2 +*.ucn +*.ur2 +*.ue2 +# Based on PAQ, RZM, CSC, CCM, and 7zip. The format consists of a PAQ, RZM, CSC, or CCM compressed file and a manifest with compression settings stored in a 7z archive. +*.uca +# A high compression rate archive format originally for DOS. +*.uha +# Web Application archive (Java-based web app) +*.war +# File-based disk image format developed to deploy Microsoft Windows. +*.wim +# XAR +*.xar +# Native format of the Open Source KiriKiri Visual Novel engine. Uses combination of block splitting and zlib compression. The filenames and pathes are stored in UTF-16 format. For integrity check, the Adler-32 hashsum is used. For many commercial games, the files are encrypted (and decoded on runtime) via so-called "cxdec" module, which implements xor-based encryption. +*.xp3 +# Yamazaki zipper archive. Compression format used in DeepFreezer archiver utility created by Yamazaki Satoshi. Read and write support exists in TUGZip, IZArc and ZipZag +*.yz1 +# The most widely used compression format on Microsoft Windows. Commonly used on Macintosh and Unix systems as well. +*.zip +*.zipx +# application/x-zoo zoo Multiple Multiple Yes +*.zoo +# Journaling (append-only) archive format with rollback capability. Supports deduplication and incremental update based on last-modified dates. Multi-threaded. Compresses in LZ77, BWT, and context mixing formats. Open source. +*.zpaq +# Archiver with a compression algorithm based on the Burrows-Wheeler transform method. +*.zz + + +### Executable ### +*.app +*.bat +*.cgi +*.com +*.exe +*.gadget +*.pif +*.vb +*.wsf + +### Git ### +# Created by git for backups. To disable backups in Git: +# $ git config --global mergetool.keepBackup false +*.orig + +# Created by git when using merge tools for conflicts +*.BACKUP.* +*.BASE.* +*.LOCAL.* +*.REMOTE.* +*_BACKUP_*.txt +*_BASE_*.txt +*_LOCAL_*.txt +*_REMOTE_*.txt + +### Go ### +# Binaries for programs and plugins +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +### Go Patch ### +/vendor/ +/Godeps/ + +### Intellij ### +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# Generated files +.idea/**/contentModel.xml + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. +# .idea/modules.xml +# .idea/*.iml +# .idea/modules +# *.iml +# *.ipr + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +# Android studio 3.1+ serialized cache file +.idea/caches/build_file_checksums.ser + +### Intellij Patch ### +# Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 + +# *.iml +# modules.xml +# .idea/misc.xml +# *.ipr + +# Sonarlint plugin +.idea/sonarlint + +### JetBrains+all Patch ### +# Ignores the whole .idea folder and all .iml files +# See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 + +.idea/ + +# Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 + +*.iml +modules.xml +.idea/misc.xml +*.ipr + +# Sonarlint plugin + +### Linux ### +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +### OSX ### +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +### VisualStudioCode ### +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history + +### Windows ### +# Windows thumbnail cache files +Thumbs.db +Thumbs.db:encryptable +ehthumbs.db +ehthumbs_vista.db + +# Dump file +*.stackdump + +# Folder config file +[Dd]esktop.ini + +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Windows Installer files +*.msi +*.msix +*.msm +*.msp + +# Windows shortcuts +*.lnk + +# End of https://www.gitignore.io/api/intellij,go,linux,osx,windows,node,python,executable,jetbrains+all,visualstudiocode,compressedarchive,git + +go.sum +/vendor diff --git a/LICENSE b/LICENSE new file mode 100755 index 0000000..2e02761 --- /dev/null +++ b/LICENSE @@ -0,0 +1,22 @@ +The BSD 3-Clause License + +Copyright (c) 2022 Sangbum Kim. +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided +that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions + and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and + the following disclaimer in the documentation and/or other materials provided with the distribution. +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or + promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF +THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md new file mode 100755 index 0000000..70bf1f8 --- /dev/null +++ b/README.md @@ -0,0 +1,24 @@ +# eighty + +[![GoDoc](https://godoc.org/amuz.es/src/go/eighty?status.png)](http://godoc.org/amuz.es/src/go/eighty) +[![Go Report](https://goreportcard.com/badge/spi-ca/eighty)](http://goreportcard.com/report/spi-ca/eighty) + +## Description + +net.http and fasthttp related utility functions. + +## Requirements + +Go 1.5 or above. + +## Installation + +Run the following command to install the package: + +``` +go get amuz.es/src/go/eighty +``` + +## Getting Started + +TBD diff --git a/client/client.go b/client/client.go new file mode 100644 index 0000000..b10145a --- /dev/null +++ b/client/client.go @@ -0,0 +1,76 @@ +package client + +import ( + "errors" + "io" + "net/http" + "net/url" + "time" +) + +var ( + httpCannotRedirectError = errors.New("this client cannot redirect") + disableRedirect = func(_ *http.Request, _ []*http.Request) error { + return httpCannotRedirectError + } + limitedRedirect = func(_ *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return errors.New("stopped after 10 redirects") + } + return nil + } +) + +// Client is an http.Client with some tunable parameters. +type Client interface { + http.RoundTripper + HttpClient() *http.Client + roundTripper() http.RoundTripper + Do(req *http.Request) (*http.Response, error) + Get(url string) (resp *http.Response, err error) + Head(url string) (resp *http.Response, err error) + Post(url string, contentType string, body io.Reader) (resp *http.Response, err error) + PostForm(url string, data url.Values) (resp *http.Response, err error) +} + +type wrappedClient struct { + http.Client +} + +func (cli *wrappedClient) HttpClient() *http.Client { + return &cli.Client +} +func (cli *wrappedClient) roundTripper() http.RoundTripper { + return cli.Transport +} + +func (cli *wrappedClient) RoundTrip(req *http.Request) (*http.Response, error) { + return cli.Client.Transport.RoundTrip(req) +} + +// NewClient returns a Client interface that has some tunable parameters. +func NewClient( + keepaliveDuration time.Duration, + connectTimeout time.Duration, + responseHeaderTimeout time.Duration, + idleConnectionTimeout time.Duration, + maxIdleConnections int, + redirectSupport bool, + serverName string, +) Client { + + var redirectChecker func(*http.Request, []*http.Request) error + if redirectSupport { + redirectChecker = limitedRedirect + } else { + redirectChecker = disableRedirect + } + + return &wrappedClient{ + Client: http.Client{ + Transport: NewRoundTripper(keepaliveDuration, connectTimeout, responseHeaderTimeout, idleConnectionTimeout, maxIdleConnections, serverName), + CheckRedirect: redirectChecker, + Jar: nil, + }, + } +} diff --git a/client/index.go b/client/index.go new file mode 100644 index 0000000..297fdf3 --- /dev/null +++ b/client/index.go @@ -0,0 +1,2 @@ +// Package client provides some customizable http client. +package client diff --git a/client/round_tripper.go b/client/round_tripper.go new file mode 100644 index 0000000..a3df06a --- /dev/null +++ b/client/round_tripper.go @@ -0,0 +1,53 @@ +package client + +import ( + "net" + "net/http" + "time" +) + +const ( + userAgentHeader = "User-Agent" +) + +type predefinedHeaderTransport struct { + useragentName string + http.Transport +} + +func (pht *predefinedHeaderTransport) RoundTrip(req *http.Request) (res *http.Response, err error) { + req.Close = pht.DisableKeepAlives + req.Header.Set(userAgentHeader, pht.useragentName) + res, err = pht.Transport.RoundTrip(req) + return +} + +// NewRoundTripper returns a http.RoundTripper that has some tunable parameters. +func NewRoundTripper( + keepaliveDuration time.Duration, + connectTimeout time.Duration, + responseHeaderTimeout time.Duration, + idleConnectionTimeout time.Duration, + maxIdleConnections int, + serverName string, +) http.RoundTripper { + + keepaliveDisabled := keepaliveDuration == 0 + dialer := &net.Dialer{ + Timeout: connectTimeout, + KeepAlive: keepaliveDuration, + } + + return &predefinedHeaderTransport{ + useragentName: serverName, + Transport: http.Transport{ + DisableKeepAlives: keepaliveDisabled, + DisableCompression: true, + MaxIdleConnsPerHost: maxIdleConnections, + DialContext: dialer.DialContext, + MaxIdleConns: maxIdleConnections, + IdleConnTimeout: idleConnectionTimeout, + ResponseHeaderTimeout: responseHeaderTimeout, + }, + } +} diff --git a/cookie.go b/cookie.go new file mode 100644 index 0000000..8a0cad1 --- /dev/null +++ b/cookie.go @@ -0,0 +1,53 @@ +package eighty + +import ( + "net/http" + "time" +) + +var ( + oldTime = time.Unix(0, 0) +) + +// SetCookieValue is a useful cookie generator for net.http. +func SetCookieValue(key string, expireDuration time.Duration, sessionSecure bool) func(http.ResponseWriter, string, string) { + return func(w http.ResponseWriter, host, newCookieValue string) { + if len(newCookieValue) > 0 { + http.SetCookie(w, + &http.Cookie{ + Name: key, + Value: newCookieValue, + Path: "/", + Domain: host, + Expires: time.Now().Add(expireDuration), + Secure: sessionSecure, + SameSite: http.SameSiteLaxMode, + MaxAge: int(expireDuration.Seconds()), + HttpOnly: true, + }, + ) + } else { + http.SetCookie(w, + &http.Cookie{ + Name: key, + Value: "_", + Path: "/", + Domain: host, + Expires: oldTime, + Secure: sessionSecure, + SameSite: http.SameSiteLaxMode, + MaxAge: -1, + HttpOnly: true, + }, + ) + } + } +} + +// GetCookieValue is the simple cookie getter. +func GetCookieValue(req *http.Request, name string) (cookieValue string) { + if cookie, _ := req.Cookie(name); cookie != nil { + cookieValue = cookie.Value + } + return +} diff --git a/cookie_fasthttp.go b/cookie_fasthttp.go new file mode 100644 index 0000000..a9b775f --- /dev/null +++ b/cookie_fasthttp.go @@ -0,0 +1,165 @@ +package eighty + +import ( + "amuz.es/src/go/misc/networking" + "amuz.es/src/go/misc/strutil" + "github.com/valyala/fasthttp" + "log" + "net" + "strings" + "time" +) + +type ( + cookieWriterFasthttpImpl struct { + key string + expireDuration time.Duration + sessionSecure bool + } + + CookieWriterFasthttp func(*fasthttp.Response, []byte, string) +) + +// NewCookieWriter is a useful cookie generator for fasthttp. +func NewCookieWriter(key string, expireDuration time.Duration, secured bool) CookieWriterFasthttp { + return (&cookieWriterFasthttpImpl{ + key: key, + expireDuration: expireDuration, + sessionSecure: secured, + }).Write +} +func (cw *cookieWriterFasthttpImpl) validateCookiePathByte(b byte) bool { + return 0x20 <= b && b < 0x7f && b != ';' +} + +func (cw *cookieWriterFasthttpImpl) validateCookieValueByte(b byte) bool { + return 0x20 <= b && b < 0x7f && b != '"' && b != ';' && b != '\\' +} + +func (cw *cookieWriterFasthttpImpl) validateCookieDomain(v []byte) (valid bool) { + // isCookieDomainName + if len(v) == 0 { + return false + } + if len(v) > 255 { + return false + } + + if v[0] == '.' { + // A cookie a domain attribute may start with a leading dot. + v = v[1:] + } + var last byte = '.' + partlen := 0 + for i := 0; i < len(v); i++ { + c := v[i] + switch { + default: + return false + case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z': + // No '_' allowed here (in contrast to package net). + valid = true + partlen++ + case '0' <= c && c <= '9': + // fine + partlen++ + case c == '-': + // Byte before dash cannot be dot. + if last == '.' { + return false + } + partlen++ + case c == '.': + // Byte before dot cannot be dot, dash. + if last == '.' || last == '-' { + return false + } + if partlen > 63 || partlen == 0 { + return false + } + partlen = 0 + } + last = c + } + + if last == '-' || partlen > 63 { + return false + } else if valid { + // isCookieDomainName + return + } + // isCookieValidIp + addr := networking.ParseIPv4(v) + return addr != nil && + !addr.Equal(net.IPv4bcast) && + !addr.IsUnspecified() && + !addr.IsMulticast() && + !addr.IsLinkLocalUnicast() +} + +func (cw *cookieWriterFasthttpImpl) sanitizeOrWarn(fieldName string, valid func(byte) bool, v string) string { + ok := true + for i := 0; i < len(v); i++ { + if valid(v[i]) { + continue + } + log.Printf("invalid byte %q in %s; dropping invalid bytes", v[i], fieldName) + ok = false + break + } + if ok { + return v + } + var build strings.Builder + for i := 0; i < len(v); i++ { + if b := v[i]; valid(b) { + build.WriteByte(b) + } + } + return build.String() +} + +func (cw *cookieWriterFasthttpImpl) sanitizeCookiePath(v string) string { + return cw.sanitizeOrWarn("Cookie.Path", cw.validateCookiePathByte, v) +} + +func (cw *cookieWriterFasthttpImpl) sanitizeCookieValue(v string) string { + v = cw.sanitizeOrWarn("Cookie.Value", cw.validateCookieValueByte, v) + if len(v) == 0 { + return v + } + if strings.IndexByte(v, ' ') >= 0 || strings.IndexByte(v, ',') >= 0 { + return `"` + v + `"` + } + return v +} + +func (cw *cookieWriterFasthttpImpl) Write(w *fasthttp.Response, host []byte, newCookieValue string) { + cookie := fasthttp.AcquireCookie() + defer fasthttp.ReleaseCookie(cookie) + + cookie.SetKey(cw.key) + cookie.SetPath(cw.sanitizeCookiePath("/")) + + if len(host) > 0 { + if cw.validateCookieDomain(host) { + cookie.SetDomainBytes(host) + } else { + log.Printf("invalid Cookie.Domain %s; dropping domain attribute", strutil.B2S(host)) + } + } + cookie.SetSecure(cw.sessionSecure) + cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode) + cookie.SetHTTPOnly(true) + + if len(newCookieValue) > 0 { + cookie.SetValue(cw.sanitizeCookieValue(newCookieValue)) + cookie.SetExpire(time.Now().Add(cw.expireDuration)) + cookie.SetMaxAge(int(cw.expireDuration.Seconds())) + } else { + cookie.SetValue("-") + cookie.SetExpire(oldTime) + cookie.SetMaxAge(-1) + } + w.Header.SetCookie(cookie) +} diff --git a/encodings.go b/encodings.go new file mode 100644 index 0000000..fedf026 --- /dev/null +++ b/encodings.go @@ -0,0 +1,88 @@ +package eighty + +import ( + "amuz.es/src/go/misc/strutil" + "github.com/valyala/fasthttp" + "mime" + "strings" +) + +// Collection of predefined request header names. +const ( + ContentTypeHeader = "Content-Type" + ContentLengthHeader = "Content-Length" + EtagHeader = "Etag" + UserAgentHeader = "User-Agent" + LastModifiedHeader = "Last-Modified" + ExpiresHeader = "Expires" + CacheControlHeader = "Cache-Control" + IfModifiedSince = "If-Modified-Since" + IfNoneMatch = "If-None-Match" + Server = "Server" + VaryHeader = "Vary" + ForwardedForIPHeader = "X-Forwarded-For" +) + +// Collection of predefined response header names. +const ( + RetryAfterHeader = "Retry-After" + LocationHeader = "Location" + FrameOptionHeader = "X-Frame-Options" + ContentTypeOptionHeader = "X-Content-Type-Options" + XssProtectionHeader = "X-XSS-Protection" + XCsrfToken = "X-CSRF-Token" + XForwardedProto = "X-Forwarded-Proto" +) + +// Collection of predefined cache header values. +const ( + CacheControlNoCache = "private, no-cache, no-store, no-transform, max-age=0, must-revalidate" + ExpiresNone = "0" +) + +// Collection of predefined mime types. +var ( + HtmlContentUTF8Type = []string{"text/html; charset=utf-8"} + HtmlContentType = []string{"text/html"} + TextContentType = []string{"text/text"} + TextContentUTF8Type = []string{"text/text; charset=utf-8"} + UrlencodeContentUTF8Type = []string{"application/x-www-form-urlencoded; charset=utf-8"} + UrlencodeContentType = []string{"application/x-www-form-urlencoded"} + JsonContentUTF8Type = []string{"application/json; charset=utf-8"} + JsonContentType = []string{"application/json"} +) + +// Collection of predefined CSRF header values. +var ( + FrameOptionDeny = []string{"DENY"} + FrameOptionSameOrigin = []string{"SAMEORIGIN"} + + ContentTypeOptionNoSniffing = []string{"nosniff"} + XssProtectionBlocking = []string{"1; mode=block"} +) + +// Collection of predefined http method names. +var ( + MethodHEAD = []byte("HEAD") + MethodGET = []byte("GET") + MethodPOST = []byte("POST") +) + +// HasContentTypeFasthttp is a simple checker that checks if an incoming request satisfies a given mime-type. +func HasContentTypeFasthttp(r *fasthttp.Request, mimetype string) bool { + contentType := strutil.B2S(r.Header.ContentType()) + if len(contentType) == 0 { + return mimetype == "application/octet-stream" + } + + for _, v := range strings.Split(contentType, ",") { + t, _, err := mime.ParseMediaType(v) + if err != nil { + break + } + if t == mimetype { + return true + } + } + return false +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..e588862 --- /dev/null +++ b/go.mod @@ -0,0 +1,29 @@ +module amuz.es/src/go/eighty + +go 1.18 + +require ( + amuz.es/src/go/logging v1.0.0 + amuz.es/src/go/misc v1.0.1 + github.com/fasthttp/router v1.4.6 + github.com/valyala/fasthttp v1.34.0 + gitlab.com/NebulousLabs/fastrand v0.0.0-20181126182046-603482d69e40 +) + +require ( + github.com/andybalholm/brotli v1.0.4 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/compress v1.15.0 // indirect + github.com/lestrrat-go/file-rotatelogs v2.4.0+incompatible // indirect + github.com/lestrrat-go/strftime v1.0.5 // indirect + github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/savsgio/gotils v0.0.0-20211223103454-d0aaa54c5899 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + go.uber.org/atomic v1.7.0 // indirect + go.uber.org/multierr v1.8.0 // indirect + go.uber.org/zap v1.21.0 // indirect + golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect + golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 // indirect +) diff --git a/index.go b/index.go new file mode 100644 index 0000000..a45b018 --- /dev/null +++ b/index.go @@ -0,0 +1,2 @@ +// Package eighty is net.http and fasthttp related utility functions. +package eighty // import "amuz.es/src/go/eighty" diff --git a/json.go b/json.go new file mode 100644 index 0000000..6d6bacd --- /dev/null +++ b/json.go @@ -0,0 +1,25 @@ +package eighty + +import ( + "amuz.es/src/go/misc" + "github.com/valyala/fasthttp" +) + +const ( + jsonMimeType = "application/json; charset=utf-8" +) + +// DumpJSONFasthttp is a simple JSON renderer for the fasthttp. +func DumpJSONFasthttp(ctx *fasthttp.RequestCtx, code int, serializable any) { + stream := misc.JSONCodec.BorrowStream(nil) + defer misc.JSONCodec.ReturnStream(stream) + + if stream.WriteVal(serializable); stream.Error != nil { + panic(stream.Error) + } else if _, err := ctx.Write(stream.Buffer()); err != nil { + panic(err) + } + + ctx.SetContentType(jsonMimeType) + ctx.SetStatusCode(code) +} diff --git a/middleware/accesslog_fasthttp.go b/middleware/accesslog_fasthttp.go new file mode 100644 index 0000000..63f4861 --- /dev/null +++ b/middleware/accesslog_fasthttp.go @@ -0,0 +1,325 @@ +package middleware + +import ( + "amuz.es/src/go/eighty" + "amuz.es/src/go/eighty/routing" + "amuz.es/src/go/logging" + "amuz.es/src/go/misc/q" + "amuz.es/src/go/misc/strutil" + "bytes" + "context" + "github.com/valyala/fasthttp" + "io" + "io/ioutil" + "net" + "runtime" + "strconv" + "strings" + "sync" + "time" +) + +const ( + dateFormat = "02/Jan/2006:15:04:05 -0700" +) + +type accessLogMiddleware struct { + writer io.WriteCloser + + apiUrlPrefix string + + logInChan chan<- string + logOutChan <-chan string + + logger logging.Logger + logWaiter sync.Mutex + waitGroup sync.WaitGroup + ctx context.Context + closer func() + + errorViewTemplateRenderer eighty.PageRenderer +} + +func (m *accessLogMiddleware) Handle(h routing.Router) routing.Router { + return func(ctx *fasthttp.RequestCtx) { + m.waitGroup.Add(1) + defer m.waitGroup.Done() + // access log 기록 + defer m.recordAccess()(ctx) + // 내부 panic 해소 + defer m.handlePanic(ctx) + m.writeBasicHeader(&ctx.Response) + h(ctx) + } +} + +// source returns a space-trimmed slice of the n'th line. +func (m *accessLogMiddleware) source(buf *strings.Builder, lines [][]byte, n int) { + n-- // in stack trace, lines are 1-indexed but our array is 0-indexed + if n < 0 || n >= len(lines) { + buf.WriteString("???") + } else { + buf.Write(lines[n]) + } +} + +// function returns, if possible, the name of the function containing the PC. +func (m *accessLogMiddleware) function(buf *strings.Builder, pc uintptr) { + fn := runtime.FuncForPC(pc) + if fn == nil { + buf.WriteString("???") + } + name := fn.Name() + // The name include the path name to the package, which is unnecessary + // since the file name is already included. Plus, it has center dots. + // That is, we see + // runtime/debug.*T·ptrmethod + // and want + // *T.ptrmethod + // Also the package path might contains dot (e.g. code.google.com/...), + // so first eliminate the path prefix + if lastslash := strings.LastIndexByte(name, '/'); lastslash >= 0 { + name = name[lastslash+1:] + } + if period := strings.IndexByte(name, '.'); period >= 0 { + name = name[period+1:] + } + buf.WriteString(strings.Replace(name, "·", ".", -1)) +} + +// 라우팅 로직에서 panic이 발생 했을경우 해당 스택을 보여준다. +// stack returns a nicely formated stack frame, skipping skip frames +func (m *accessLogMiddleware) getStack(buf *strings.Builder, skip int) { + // As we loop, we open files and read them. These variables record the currently + // loaded file. + var lines [][]byte + var lastFile string + for i := skip; ; i++ { + // Skip the expected number of frames + pc, file, line, ok := runtime.Caller(i) + if !ok { + break + } else if i > skip { + buf.WriteByte('\n') + } + if paths := strings.SplitN(file, "src/", 2); len(paths) == 1 { + // Print this much at least. If we can't find the source, it won't show. + //_, _ = fmt.Fprintf(buf, "%s:%d (0x%x)\n", file, line, pc) + buf.WriteString(file) + } else if vendors := strings.SplitN(paths[1], "vendor/", 2); len(vendors) == 1 { + // Print this much at least. If we can't find the source, it won't show. + //_, _ = fmt.Fprintf(buf, "%s:%d (0x%x)\n", paths[1], line, pc) + buf.WriteString(paths[1]) + } else { + // Print this much at least. If we can't find the source, it won't show. + //_, _ = fmt.Fprintf(buf, "%s:%d (0x%x)\n", vendors[1], line, pc) + buf.WriteString(vendors[1]) + } + buf.WriteByte(':') + buf.WriteString(strconv.FormatInt(int64(line), 10)) + buf.WriteString(" (0x") + buf.WriteString(strconv.FormatInt(int64(pc), 16)) + buf.WriteString(")\n") + + // ----- + //_, _ = fmt.Fprintf(buf, "\t%s: %s\n", function(pc), source(lines, line)) + buf.WriteByte('\t') + m.function(buf, pc) + buf.WriteString(": ") + if file == lastFile { + buf.WriteString("???") + } else if data, err := ioutil.ReadFile(file); err != nil { + buf.WriteString("???") + } else { + lines = bytes.Split(data, []byte{'\n'}) + lastFile = file + m.source(buf, lines, line) + } + } +} + +func (m *accessLogMiddleware) handlePanic(ctx *fasthttp.RequestCtx) { + panicObj := recover() + if panicObj == nil { + return + } + errorType, err := eighty.WrapHandledError(panicObj) + if err != nil { + var buf strings.Builder + buf.WriteString("PANIC! ") + buf.WriteString(err.Error()) + buf.WriteString("\n--------\nREQUEST\n") + _, _ = ctx.Request.WriteTo(&buf) + buf.WriteString("\n--------\nSTACK\n") + m.getStack(&buf, 3) + buf.WriteString("\n--------") + m.logger.Error(buf.String()) + } + isAPI := bytes.HasPrefix(ctx.RequestURI(), []byte(m.apiUrlPrefix)) + if isAPI { + errorType.RenderAPI(ctx, err) + } else { + errorType.RenderPage(ctx, m.errorViewTemplateRenderer, err) + } +} + +func (m *accessLogMiddleware) writeBasicHeader(w *fasthttp.Response) { + w.Header.Set(eighty.FrameOptionHeader, eighty.FrameOptionSameOrigin[0]) + w.Header.Set(eighty.ContentTypeOptionHeader, eighty.ContentTypeOptionNoSniffing[0]) + w.Header.Set(eighty.XssProtectionHeader, eighty.XssProtectionBlocking[0]) +} + +func (m *accessLogMiddleware) recordAccess() routing.Router { + now := time.Now() + return func(ctx *fasthttp.RequestCtx) { + var ( + dur = time.Since(now) + builder strings.Builder + ) + _, _ = builder.Write(m.remoteAddr(ctx.RemoteAddr(), &ctx.Request)) + _, _ = builder.WriteString(` - - [`) + _, _ = builder.WriteString(now.Format(dateFormat)) + _, _ = builder.WriteString(`] "`) + _, _ = builder.Write(ctx.Method()) + _ = builder.WriteByte(' ') + _, _ = builder.Write(ctx.RequestURI()) + _ = builder.WriteByte(' ') + if ctx.Request.Header.IsHTTP11() { + _, _ = builder.WriteString("HTTP/1.1") + } else { + _, _ = builder.WriteString("HTTP/1.0") + } + _, _ = builder.WriteString(`" `) + _, _ = builder.Write(strutil.FormatIntToBytes(ctx.Response.StatusCode())) + _ = builder.WriteByte(' ') + _, _ = builder.Write(strutil.FormatIntToBytes(ctx.Response.Header.ContentLength())) + _, _ = builder.WriteString(` "`) + _, _ = builder.Write(ctx.Request.Header.Referer()) + _, _ = builder.WriteString(`" "`) + _, _ = builder.Write(ctx.Request.Header.UserAgent()) + _, _ = builder.WriteString(`" `) + _, _ = builder.Write(strutil.FormatIntToBytes(int(dur.Nanoseconds() / time.Millisecond.Nanoseconds()))) + _ = builder.WriteByte(' ') + _, _ = builder.Write(ctx.Request.Host()) + _ = builder.WriteByte('\n') + + select { + case <-m.ctx.Done(): + m.logger.Error("cannot accesslog record: ", builder.String()) + default: + m.logInChan <- builder.String() + } + } +} + +func (m *accessLogMiddleware) Close() { + defer m.writer.Close() + m.closer() + close(m.logInChan) + m.logWaiter.Lock() + defer m.logWaiter.Unlock() +} + +func (m *accessLogMiddleware) lineByLineWriter() { + var ( + ticker = time.NewTicker(200 * time.Millisecond) + maxsz = 1024 * 1024 + sz = 0 + rcvsz = 0 + buf = make([]byte, maxsz) + flusher = func() { + if sz > 0 { + if _, err := m.writer.Write(buf[:sz]); err != nil { + m.logger.Error("cannot write accesslog chunk : ", err) + } + //reset + sz = 0 + } + } + ) + + m.logWaiter.Lock() + defer func() { + defer m.logWaiter.Unlock() + ticker.Stop() + flusher() + }() + for { + select { + case logItem, ok := <-m.logOutChan: + //or do the next job + if !ok { + return + } + rcvsz = len(logItem) + if maxsz < sz+rcvsz { + flusher() + } + if rcvsz > 0 { + // append + copy(buf[sz:], logItem) + sz += rcvsz + } + case <-ticker.C: + // if deadline exceeded write + flusher() + } + } +} + +// strip port from addresses with hostname, ipv4 or ipv6 +func (m *accessLogMiddleware) stripPort(address string) string { + if h, _, err := net.SplitHostPort(address); err == nil { + return h + } + + return address +} + +// The remote address of the client. When the 'X-Forwarded-For' +// header is set, then it is used instead. +func (m *accessLogMiddleware) remoteAddr(remoteAddr net.Addr, r *fasthttp.Request) (ret []byte) { + if ret = r.Header.Peek(eighty.ForwardedForIPHeader); ret == nil { + ret = []byte(remoteAddr.String()) + } + return +} + +func (m *accessLogMiddleware) remoteHost(remoteAddr net.Addr, r *fasthttp.Request) string { + a := m.remoteAddr(remoteAddr, r) + h := m.stripPort(strutil.B2S(a)) + if h != "" { + return h + } + + return "-" +} + +// AccessLogMiddleware returns a routing.Middleware that handles error handling and access logging. +func AccessLogMiddleware( + apiUrlPrefix string, + logWriter io.WriteCloser, + templateRenderer eighty.PageRenderer, + logger logging.Logger) (handler routing.Middleware, closer func(), err error) { + ctx, canceler := context.WithCancel(context.Background()) + inchan, outchan := q.NewStringQueue() + impl := &accessLogMiddleware{ + writer: logWriter, + apiUrlPrefix: apiUrlPrefix, + logInChan: inchan, + logOutChan: outchan, + logger: logger, + ctx: ctx, + errorViewTemplateRenderer: templateRenderer, + } + impl.closer = func() { + if ctx.Err() == nil { + canceler() + } + impl.waitGroup.Wait() + } + + go impl.lineByLineWriter() + + return impl.Handle, impl.Close, nil +} diff --git a/middleware/cache_control.go b/middleware/cache_control.go new file mode 100644 index 0000000..86b83cb --- /dev/null +++ b/middleware/cache_control.go @@ -0,0 +1,34 @@ +package middleware + +import ( + "amuz.es/src/go/eighty" + "net/http" + "strconv" + "time" +) + +// CacheControlFunc returns a func(next http.Handler) http.Handler that handles cache control header. +func CacheControlFunc(debug bool, startupTime time.Time) func(next http.Handler) http.Handler { + baseVersion := startupTime.Unix() + return func(next http.Handler) http.Handler { + if debug { + return next + } else { + fn := func(w http.ResponseWriter, r *http.Request) { + defer next.ServeHTTP(w, r) + if receivedVersion, err := strconv.ParseInt(r.URL.Query().Get("v"), 10, 64); err == nil && receivedVersion >= baseVersion { + var cacheDuration int64 = 2592000 + if pushedDuration, err := strconv.ParseInt(r.URL.Query().Get("d"), 10, 64); err == nil && pushedDuration > cacheDuration { + cacheDuration = pushedDuration + } + //add header + cacheDurationStr := strconv.FormatInt(cacheDuration, 10) + w.Header().Add(eighty.CacheControlHeader, "public, max-age="+cacheDurationStr) + w.Header().Add(eighty.ExpiresHeader, cacheDurationStr) + } + w.Header().Add(eighty.VaryHeader, "User-Agent") + } + return http.HandlerFunc(fn) + } + } +} diff --git a/middleware/csrf_fasthttp.go b/middleware/csrf_fasthttp.go new file mode 100644 index 0000000..9fcc81f --- /dev/null +++ b/middleware/csrf_fasthttp.go @@ -0,0 +1,185 @@ +package middleware + +import ( + "amuz.es/src/go/eighty" + "amuz.es/src/go/eighty/routing" + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "github.com/valyala/fasthttp" + "gitlab.com/NebulousLabs/fastrand" + "io" + "net/http" + "time" +) + +const ( + // the name of CSRF cookie + CsrfCookieName = "csrf_token" + + // the name of CSRF header + csrfContextKey = "csrf" + + csrfTokenLength = 32 +) + +// reasons for CSRF check failures +var ( + csrfSafeMethods = [][]byte{ + []byte(http.MethodGet), + []byte(http.MethodHead), + []byte(http.MethodOptions), + []byte(http.MethodTrace), + } + mockCSRFRouterMiddleware = func(next routing.Router) routing.Router { return next } +) + +type ( + csrfToken struct { + payload string + } + csrfMiddleware struct { + writer eighty.CookieWriterFasthttp + } +) + +// CSRFToken returns a CSRF token in the current request context. +// If the token was not found in the request, zero-value returned. +func CSRFToken(ctx *fasthttp.RequestCtx) (token string) { + if ctx, ok := ctx.UserValue(csrfContextKey).(*csrfToken); ok && ctx != nil { + token = ctx.payload + } + return +} + +// Masks/unmasks the given data *in place* +// with the given key +// Slices must be of the same length, or csrfOneTimePad will panic +func (m *csrfMiddleware) csrfOneTimePad(data, key []byte) { + n := len(data) + if n != len(key) { + panic("Lengths of slices are not equal") + } + + for i := 0; i < n; i++ { + data[i] ^= key[i] + } +} + +func (m *csrfMiddleware) isMethodSafe(s []byte) (safe bool) { + // checks if the given slice contains the given string + for _, v := range csrfSafeMethods { + if safe = subtle.ConstantTimeCompare(v, s) == 1; safe { + break + } + } + return +} + +// A token is generated by returning csrfTokenLength bytes +// from crypto/rand +func (m *csrfMiddleware) generateToken() []byte { + bytes := make([]byte, csrfTokenLength) + + if _, err := io.ReadFull(rand.Reader, bytes); err != nil { + panic(err) + } + + return bytes +} + +func (m *csrfMiddleware) tokenSerializer(data []byte, mask bool) (encoded string) { + if !mask || len(data) != csrfTokenLength { + return + } + + // csrfTokenLength*2 == len(enckey + token) + result := make([]byte, 2*csrfTokenLength) + // the first half of the result is the OTP + // the second half is the masked token itself + key := result[:csrfTokenLength] + token := result[csrfTokenLength:] + copy(token, data) + + // generate the random token + if _, err := io.ReadFull(fastrand.Reader, key); err != nil { + panic(err) + } + m.csrfOneTimePad(token, key) + + return base64.StdEncoding.EncodeToString(result) +} + +func (m *csrfMiddleware) tokenDeserializer(data []byte, unmask bool) (decoded []byte) { + payloadSize := base64.StdEncoding.DecodedLen(len(data)) + if payloadSize != csrfTokenLength*2 { + return + } + + decoded = make([]byte, payloadSize) + n, err := base64.StdEncoding.Decode(decoded, data) + if err != nil || n < payloadSize { + return nil + } + + decoded = decoded[:n] + if unmask { + key := decoded[:csrfTokenLength] + decoded = decoded[csrfTokenLength:] + m.csrfOneTimePad(decoded, key) + } + return +} + +func (m *csrfMiddleware) verifyToken(realToken, sentToken []byte) bool { + realN := len(realToken) + sentN := len(sentToken) + if realN == csrfTokenLength && sentN == csrfTokenLength { + return subtle.ConstantTimeCompare(realToken, sentToken) == 1 + } + return false +} + +func (m *csrfMiddleware) Handle(h routing.Router) routing.Router { + return func(ctx *fasthttp.RequestCtx) { + var ( + realToken []byte + internalToken csrfToken + tokenCreated bool + ) + + if cookieValue := ctx.Request.Header.Cookie(CsrfCookieName); len(cookieValue) > 0 { + realToken = m.tokenDeserializer(cookieValue, false) + } + tokenCreated = len(realToken) != csrfTokenLength + if tokenCreated { + realToken = m.generateToken() + } + internalToken = csrfToken{ + payload: m.tokenSerializer(realToken, true), + } + ctx.SetUserValue(csrfContextKey, &internalToken) + + if m.isMethodSafe(ctx.Method()) { + h(ctx) + } else if sentToken := m.tokenDeserializer(ctx.Request.Header.Peek(eighty.XCsrfToken), true); !m.verifyToken(realToken, sentToken) { + panic(eighty.HandledErrorBadRequest) + } else { + h(ctx) + } + ctx.Response.Header.Set(eighty.VaryHeader, "Cookie") + if tokenCreated { + m.writer(&ctx.Response, ctx.Host(), m.tokenSerializer(realToken, false)) + } + } +} + +// CSRFFunc returns a routing.Middleware that handles CSRF validation logic. +func CSRFFunc(isDebug bool, expire time.Duration, secure bool) (w routing.Middleware) { + if isDebug { + return mockCSRFRouterMiddleware + } + return (&csrfMiddleware{ + writer: eighty.NewCookieWriter(CsrfCookieName, expire, secure), + }).Handle +} diff --git a/middleware/index.go b/middleware/index.go new file mode 100644 index 0000000..b6f6fa3 --- /dev/null +++ b/middleware/index.go @@ -0,0 +1,2 @@ +// Package middleware provides the collection of processing filters for the http request. +package middleware diff --git a/middleware/internal/secret_box_dto.go b/middleware/internal/secret_box_dto.go new file mode 100644 index 0000000..69c840a --- /dev/null +++ b/middleware/internal/secret_box_dto.go @@ -0,0 +1,56 @@ +package internal + +// +//import ( +// "amuz.es/src/mercury/endpoint/misc" +// "encoding/base64" +// "github.com/tinylib/msgp/msgp" +// "io" +// "strings" +//) +// +//type ( +// secretPayload map[string]string +// SecretToolboxImpl struct { +// EncryptWriter func(dst io.Writer) (io.WriteCloser, error) +// DecryptReader func(src io.Reader) (io.Reader, error) +// } +//) +// +//// 인터페이스가 실제구현체랑 호환되는가 +//var _ misc.SecretToolbox = (*SecretToolboxImpl)(nil) +// +//func (tbx *SecretToolboxImpl) Encrypt(data map[string]string) (encrypted string, err error) { +// +// var ( +// payload secretPayload = data +// stringWriter strings.Builder +// ) +// +// encryptor, err := tbx.EncryptWriter(base64.NewEncoder(base64.URLEncoding, &stringWriter)) +// if err != nil { +// return +// } else if err = msgp.Encode(encryptor, payload); err != nil { +// return +// } +// +// _ = encryptor.Close() +// +// encrypted = stringWriter.String() +// return +//} +//func (tbx *SecretToolboxImpl) Decrypt(encrypted string) (data map[string]string, err error) { +// var ( +// stringReader = strings.NewReader(encrypted) +// payload = secretPayload{} +// ) +// +// decryptor, err := tbx.DecryptReader(base64.NewDecoder(base64.URLEncoding, stringReader)) +// if err != nil { +// return +// } else if err = msgp.Decode(decryptor, &payload); err != nil { +// return +// } +// data = payload +// return +//} diff --git a/middleware/internal/session_dto.go b/middleware/internal/session_dto.go new file mode 100644 index 0000000..8bbb473 --- /dev/null +++ b/middleware/internal/session_dto.go @@ -0,0 +1,134 @@ +package internal + +// +//import ( +// misc2 "amuz.es/src/mercury/endpoint/misc" +// "amuz.es/src/mercury/service/models" +// "amuz.es/src/go/misc" +// "amuz.es/src/mercury/util/mycrypt" +// "github.com/pkg/errors" +// "io" +// "log" +// "reflect" +//) +// +//// 인터페이스가 실제구현체랑 호환되는가 +//var _ misc2.Session = (*SessionImpl)(nil) +// +//type SessionImpl struct { +// FBox mycrypt.SecretBox +// FUser models.User +// FToken misc.UUID +// FData map[string]any +// FSyncer func(newdata SessionImpl) error +//} +// +//func (sess *SessionImpl) Id() misc.UUID { return sess.FToken } +//func (sess *SessionImpl) User() models.User { return sess.FUser } +//func (sess *SessionImpl) SetUser(user models.User) { +// sess.Create() +// sess.FUser = user +//} +//func (sess *SessionImpl) Delete() { +// if !sess.FToken.IsZero() { +// sess.FToken = misc.UUID{} +// } +//} +//func (sess *SessionImpl) Create() { +// if sess.FToken.IsZero() { +// sess.FToken.Random() +// } +// if sess.FData == nil { +// sess.FData = make(map[string]any) +// } +//} +//func (sess *SessionImpl) Get(key string) any { +// if sess.FData == nil { +// return nil +// } +// data, ok := sess.FData[key] +// if !ok { +// return nil +// } else { +// return data +// } +//} +//func (sess *SessionImpl) Set(key string, data any) { +// sess.Create() +// reflectVal := reflect.ValueOf(data) +// if reflectVal.IsValid() && reflectVal.Kind() == reflect.Ptr { +// reflectVal = reflectVal.Elem() +// } +// +// switch reflectVal.Kind() { +// case reflect.Struct: +// sess.FData[key] = data +// case reflect.Array, reflect.Map, reflect.Slice, reflect.String: +// if reflectVal.Len() == 0 { +// delete(sess.FData, key) +// } else { +// sess.FData[key] = data +// } +// case reflect.Bool: +// if !reflectVal.Bool() { +// delete(sess.FData, key) +// } else { +// sess.FData[key] = data +// } +// case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: +// if reflectVal.Int() == 0 { +// delete(sess.FData, key) +// } else { +// sess.FData[key] = data +// } +// case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: +// if reflectVal.Uint() == 0 { +// delete(sess.FData, key) +// } else { +// sess.FData[key] = data +// } +// case reflect.Float32, reflect.Float64: +// if reflectVal.Float() == 0 { +// delete(sess.FData, key) +// } else { +// sess.FData[key] = data +// } +// case reflect.Interface, reflect.Ptr: +// if reflectVal.IsNil() { +// delete(sess.FData, key) +// } else { +// sess.FData[key] = data +// } +// default: +// //do nothing +// } +//} +//func (sess *SessionImpl) ApplyChanges() (err error) { +// if sess.FSyncer == nil { +// err = errors.New("cannot sync session,FSyncer is nil!! ") +// return +// } else if syncererr := sess.FSyncer(*sess); syncererr != nil { +// err = errors.Wrap(syncererr, "cannot sync session: ") +// } +// log.Print("session id ", sess.FToken.ToHexString(), ", FUser ", sess.FUser, ", FData ", sess.FData) +// return +//} +// +//func (sess *SessionImpl) NewEncryptReader(r io.Reader) (io.Reader, error) { +// return sess.FBox.NewEncryptReader(r) +//} +//func (sess *SessionImpl) NewDecryptReader(r io.Reader) (io.Reader, error) { +// return sess.FBox.NewDecryptReader(r) +//} +//func (sess *SessionImpl) NewEncryptWriter(w io.Writer) (io.WriteCloser, error) { +// return sess.FBox.NewEncryptWriter(w) +//} +//func (sess *SessionImpl) NewDecryptWriter(w io.Writer) (io.WriteCloser, error) { +// return sess.FBox.NewDecryptWriter(w) +//} +//func (sess *SessionImpl) EncryptedSize(size uint64) (uint64, error) { +// return sess.FBox.EncryptedSize(size) +//} +//func (sess *SessionImpl) DecryptedSize(size uint64) (uint64, error) { +// return sess.FBox.DecryptedSize(size) +//} diff --git a/middleware/internal/session_pool.go b/middleware/internal/session_pool.go new file mode 100644 index 0000000..9a049e0 --- /dev/null +++ b/middleware/internal/session_pool.go @@ -0,0 +1,39 @@ +package internal + +// +//import ( +// "amuz.es/src/mercury/util/mycrypt" +// "sync" +//) +// +//type SessionImplPool struct { +// box mycrypt.SecretBox +// pool sync.Pool +// init sync.Once +//} +// +//func (p *SessionImplPool) Acquire() *SessionImpl { +// p.init.Do(func() { +// p.pool.New = p.alloc +// }) +// b := p.pool.Get().(*SessionImpl) +// return b +//} +// +//func (p *SessionImplPool) alloc() any { +// impl := &SessionImpl{ +// FBox: p.box, +// } +// return impl +//} +// +//func (p *SessionImplPool) Release(b *SessionImpl) { +// if b == nil { +// return +// } +// b.FData = nil +// b.FSyncer = nil +// b.FToken.Clear() +// b.FUser = nil +// p.pool.Put(b) +//} diff --git a/middleware/secret_box_fasthttp.go b/middleware/secret_box_fasthttp.go new file mode 100644 index 0000000..d06e676 --- /dev/null +++ b/middleware/secret_box_fasthttp.go @@ -0,0 +1,46 @@ +package middleware + +// +//import ( +// "amuz.es/src/mercury/bootstrap" +// "amuz.es/src/mercury/endpoint/middleware2/internal" +// rmisc "amuz.es/src/mercury/endpoint/misc" +// "amuz.es/src/go/logging" +// "amuz.es/src/mercury/util/mycrypt" +// "github.com/pkg/errors" +// "github.com/valyala/fasthttp" +//) +// +//const ( +// SecretBoxMiddlewareName = "secret_box" +//) +// +//type ( +// SecretToolbox interface { +// Encrypt(data map[string]string) (string, error) +// Decrypt(encrypted string) (map[string]string, error) +// } +//) +// +//func SecretBoxFunc(box mycrypt.SecretBox, parentLogger logging.Logger) rmisc.SecretBoxMiddleware { +// if box == nil { +// panic(errors.New("secret box needed")) +// } +// logger := parentLogger.Named("secret_box") +// var toolBox rmisc.SecretToolbox = &internal.SecretToolboxImpl{ +// EncryptWriter: box.NewEncryptWriter, +// DecryptReader: box.NewDecryptReader, +// } +// return func(next rmisc.Router) rmisc.Router { +// return func(ctx *fasthttp.RequestCtx) { +// ctx.SetUserValue(rmisc.SecretBoxContextKey, toolBox) +// logger.Info("injected secret box") +// next(ctx) +// } +// } +// +//} +// +//func NewSecretBox(sessionConfig bootstrap.SessionConfig) rmisc.RouterSecretBox { +// return mycrypt.NewSecretBox(sessionConfig.Secret()) +//} diff --git a/middleware/session_fasthttp.go b/middleware/session_fasthttp.go new file mode 100644 index 0000000..37953d8 --- /dev/null +++ b/middleware/session_fasthttp.go @@ -0,0 +1,192 @@ +package middleware + +// +//import ( +// "amuz.es/src/mercury/bootstrap" +// "amuz.es/src/mercury/common" +// "amuz.es/src/mercury/endpoint/middleware2/internal" +// rmisc "amuz.es/src/mercury/endpoint/misc" +// "amuz.es/src/mercury/service" +// "amuz.es/src/mercury/service/models" +// "amuz.es/src/go/logging" +// "amuz.es/src/go/misc" +// "amuz.es/src/mercury/util/mycrypt" +// "amuz.es/src/go/eighty" +// "bytes" +// "encoding/base64" +// "encoding/binary" +// "github.com/pkg/errors" +// "github.com/valyala/fasthttp" +// "go.uber.org/multierr" +// "strings" +// "time" +//) +// +//func SessionFunc( +// box mycrypt.SecretBox, +// sessionCookieName string, +// sessionExpire time.Duration, +// sessionSecure bool, +// parentLogger logging.Logger, +// userService service.UserService, +//) (rmisc.SessionMiddleware, error) { +// logger := parentLogger.Named("session") +// if box == nil { +// return nil, errors.New("secret box needed") +// } +// sessionKeyReader := readSessionKeyValue(box, sessionCookieName) +// sessionKeyWriter := writeSessionKeyValue(box, sessionCookieName, sessionExpire, sessionSecure) +// sessionSyncGenerator := syncSessioner(logger, +// sessionKeyWriter, +// sessionService.Update, +// sessionService.Delete, +// ) +// var sessionImplPool internal.SessionImplPool +// return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { +// return func(ctx *fasthttp.RequestCtx) { +// sessionData := sessionImplPool.Acquire() +// defer sessionImplPool.Release(sessionData) +// var ( +// token misc.UUID +// userId *int +// user models.User +// data map[string]any +// err error +// ) +// sessionData.FBox = box +// token, err = sessionKeyReader(&ctx.Request) +// if err != nil { +// logger.Debug("cannot extract token, error:", err) +// } else if token.IsZero() { +// logger.Debug("session is empty") +// } else if userId, data, err = sessionService.Get(token); err != nil { +// logger.Debug("cannot find session data from redis token:", token, " error : ", err) +// } else { +// sessionData.FToken = token +// sessionData.FData = data +// } +// +// if userId == nil { +// logger.Debug("user id is nil token:", token.ToHexString()) +// } else if user, err = userService.GetUser(ctx, *userId); err != nil { +// logger.Debug("cannot query user, Id:", userId, " error : ", err) +// } else if user == nil { +// logger.Debug("user not found, Id:", userId) +// } else if user.Status() != common.UserStatusEnabled { +// logger.Debug("disabled user, Id:", userId, " status:", user.Status()) +// } else { +// logger.Info("user : ", user) +// sessionData.FUser = user +// } +// sessionData.FSyncer = sessionSyncGenerator(&ctx.Request, &ctx.Response, token) +// ctx.SetUserValue(rmisc.SessionContextKey, sessionData) +// next(ctx) +// } +// }, nil +//} +// +//func readSessionKeyValue(box mycrypt.SecretBox, cookieName string) func(r *fasthttp.Request) (misc.UUID, error) { +// return func(r *fasthttp.Request) (token misc.UUID, err error) { +// defer func() { +// if panicErr := recover(); panicErr != nil { +// err = multierr.Append(err, panicErr.(error)) +// } +// }() +// sessionCookieValue := r.Header.Cookie(cookieName) +// if len(sessionCookieValue) == 0 { +// err = errors.New("session is Empty") +// return +// } +// decryptor, err := box.NewDecryptReader(base64.NewDecoder(base64.StdEncoding, bytes.NewReader(sessionCookieValue))) +// if err != nil { +// return +// } else if err = binary.Read(decryptor, binary.LittleEndian, &token); err != nil { +// return +// } +// return +// } +//} +// +//func writeSessionKeyValue( +// box mycrypt.SecretBox, cookieName string, +// expireDuration time.Duration, secure bool) func(*fasthttp.Response, []byte, misc.UUID) error { +// cookieWriter := eighty.SetCookieValueFasthttp(cookieName, expireDuration, secure) +// return func(w *fasthttp.Response, hostname []byte, token misc.UUID) (err error) { +// defer func() { +// if panicErr := recover(); panicErr != nil { +// err = multierr.Append(err, panicErr.(error)) +// } +// }() +// if token.IsZero() { +// cookieWriter(w, hostname, "") +// return +// } +// +// var b strings.Builder +// encryptor, err := box.NewEncryptWriter(base64.NewEncoder(base64.StdEncoding, &b)) +// if err != nil { +// return +// } else if err = binary.Write(encryptor, binary.LittleEndian, token[:]); err != nil { +// return +// } +// _ = encryptor.Close() +// +// cookieWriter(w, hostname, b.String()) +// return +// } +//} +// +//func syncSessioner( +// logger logging.Logger, +// sessionKeyWriter func(*fasthttp.Response, []byte, misc.UUID) error, +// sessionDataUpdater func(misc.UUID, *int, map[string]any) error, +// sessionDataDeleter func(misc.UUID) error, +//) func(*fasthttp.Request, *fasthttp.Response, misc.UUID) func(newdata internal.SessionImpl) error { +// return func(r *fasthttp.Request, w *fasthttp.Response, oldToken misc.UUID) func(newdata internal.SessionImpl) error { +// return func(newdata internal.SessionImpl) (err error) { +// newToken := newdata.FToken +// tokenNotChanged := oldToken.Equal(newToken) +// if tokenNotChanged { +// // do nothing +// } else if internalError := sessionKeyWriter(w, r.Host(), newToken); internalError != nil { +// logger.Debug("cannot generate cookie, error:", internalError) +// err = errors.Wrap(internalError, "cannot generate cookie") +// return +// } +// +// if newToken.IsZero() { +// if !tokenNotChanged { +// deleteErr := sessionDataDeleter(oldToken) +// logger.Debug("cannot delete session, error:", deleteErr) +// } +// } else if user := newdata.FUser; user == nil { +// err = sessionDataUpdater(newToken, nil, newdata.FData) +// } else { +// userId := user.Id() +// err = sessionDataUpdater(newToken, &userId, newdata.FData) +// } +// if err != nil { +// logger.Debug("cannot save session data, error:", err) +// err = errors.Wrap(err, "cannot save session data") +// } +// return +// } +// } +//} +// +//func NewSessionMiddleware( +// box rmisc.RouterSecretBox, +// sessionConfig bootstrap.SessionConfig, +// parentLogger rmisc.RouterLogger, +// sessionService service.SessionService, +// userService service.UserService) (rmisc.SessionMiddleware, error) { +// return SessionFunc( +// box, +// sessionConfig.CookieName(), +// sessionConfig.ExpireTTL(), +// sessionConfig.Secure(), +// parentLogger, +// sessionService, +// userService, +// ) +//} diff --git a/routing/context.go b/routing/context.go new file mode 100644 index 0000000..6a17fe3 --- /dev/null +++ b/routing/context.go @@ -0,0 +1,71 @@ +package routing + +import ( + "github.com/fasthttp/router" + "net/http" +) + +type ( + // RouterContext is a fasthttp request url routing context. + RouterContext interface { + // UrlResolver returns a UrlResolver. + UrlResolver() UrlResolver + urlFor() UrlFor + // UrlPrefix returns the URL prefix that is added across the entire routing group. + UrlPrefix() string + withRegister( + r *router.Router, + parentNames []string, + parentPaths []string, + middlewares ...Middleware, + ) RouterRegistry + // BuildRouter returns a RouterRegistry for register grouped routing. + BuildRouter(errHandleMiddleware Middleware) RouterRegistry + } + + routerContextImpl struct { + urlPrefix string + reverseRouter UrlFor + } +) + +func (ctx *routerContextImpl) UrlResolver() UrlResolver { return ctx.reverseRouter.ToResolver() } +func (ctx *routerContextImpl) urlFor() UrlFor { return ctx.reverseRouter } +func (ctx *routerContextImpl) UrlPrefix() string { return ctx.urlPrefix } +func (ctx *routerContextImpl) withRegister( + r *router.Router, + parentNames []string, + parentPaths []string, + middlewares ...Middleware, +) RouterRegistry { + return &routerRegistryImpl{ + routerContextImpl: ctx, + r: r, + middlewares: middlewares, + parentNames: parentNames, + parentPaths: parentPaths, + } +} +func (c *routerContextImpl) BuildRouter(errHandlemiddleware Middleware) RouterRegistry { + r := router.New() + r.RedirectTrailingSlash = true + r.RedirectFixedPath = true + r.HandleOPTIONS = false + r.HandleMethodNotAllowed = true + if errHandlemiddleware != nil { + r.NotFound = JustCode(http.StatusNotFound, errHandlemiddleware) + r.MethodNotAllowed = JustCode(http.StatusMethodNotAllowed, errHandlemiddleware) + } + return &routerRegistryImpl{routerContextImpl: c, r: r} +} + +// NewRouterContext returns a RouterContext. +func NewRouterContext( + urlPrefix string, + reverseRouter UrlFor, +) (ctx RouterContext) { + return &routerContextImpl{ + urlPrefix: urlPrefix, + reverseRouter: reverseRouter, + } +} diff --git a/routing/helper.go b/routing/helper.go new file mode 100644 index 0000000..f6a596e --- /dev/null +++ b/routing/helper.go @@ -0,0 +1,25 @@ +package routing + +import ( + "amuz.es/src/go/eighty" + "github.com/valyala/fasthttp" +) + +// ApplyMiddlware is a function that applies the given middleware for the handler. +func ApplyMiddlware(source Router, middlewares ...Middleware) (handler Router) { + handler = source + for i := len(middlewares) - 1; i >= 0; i-- { + handler = middlewares[i](handler) + } + return +} + +// JustCode is a simple handler function that just only returns the http status code. +func JustCode(statusHandler eighty.HandledError, middlewares ...Middleware) (handler Router) { + return ApplyMiddlware( + func(ctx *fasthttp.RequestCtx) { + panic(statusHandler) + }, + middlewares..., + ) +} diff --git a/routing/iface.go b/routing/iface.go new file mode 100644 index 0000000..a740a69 --- /dev/null +++ b/routing/iface.go @@ -0,0 +1,14 @@ +package routing + +import ( + "github.com/valyala/fasthttp" +) + +type ( + // Router is an alias type of the fasthttp.RequestHandler. + Router = fasthttp.RequestHandler + // NestedRouter is a generator function that returns the Router handler. + NestedRouter func() Router + // Middleware is a wrapping function that filters the request. + Middleware = func(next Router) Router +) diff --git a/routing/index.go b/routing/index.go new file mode 100644 index 0000000..ac2fe9f --- /dev/null +++ b/routing/index.go @@ -0,0 +1,2 @@ +// Package routing provides the name resolving utility for the fasthttp. +package routing diff --git a/routing/registry.go b/routing/registry.go new file mode 100644 index 0000000..739d3f2 --- /dev/null +++ b/routing/registry.go @@ -0,0 +1,99 @@ +package routing + +import ( + "github.com/fasthttp/router" + "github.com/valyala/fasthttp" + "strings" +) + +type ( + // RouterRegistry is a fasthttp request url routing builder. + RouterRegistry interface { + RouterContext + // Name returns the current group name. + Name() string + // ToContext returns the current group name. + ToContext() RouterContext + // Register is a registration method for http request. + Register( + name string, path string, params []string, + handler Router, + middlewares []Middleware, + methods ...string, + ) + // RegisterNested is a registration method for http request with a router generator. + RegisterNested( + name string, path string, params []string, + routerGenerator NestedRouter, + middlewares []Middleware, + methods ...string, + ) + // Wrap returns a child RouterRegistry with specified name and path. + Wrap(name string, path string, middlewares ...Middleware) RouterRegistry + // Handler is a handler method that process incoming requests. + // It implements the Router interface. + Handler(ctx *fasthttp.RequestCtx) + } + routerRegistryImpl struct { + *routerContextImpl + name string + r *router.Router + middlewares []Middleware + parentNames []string + parentPaths []string + } +) + +func (r *routerRegistryImpl) Handler(ctx *fasthttp.RequestCtx) { + r.r.Handler(ctx) +} +func (r *routerRegistryImpl) Name() string { return r.name } +func (r *routerRegistryImpl) ToContext() RouterContext { return r.routerContextImpl } + +func (r *routerRegistryImpl) Register( + name string, path string, params []string, + handler Router, + middlewares []Middleware, + methods ...string, +) { + + fullPath := r.reverseRouter.MustAddGr(name, path, r.parentNames, r.parentPaths, params...) + mixedRouter := ApplyMiddlware(handler, append(r.middlewares, middlewares...)...) + for _, method := range methods { + r.r.Handle(method, fullPath, mixedRouter) + } +} + +func (r *routerRegistryImpl) RegisterNested( + name string, path string, params []string, + routerGenerator NestedRouter, + middlewares []Middleware, + methods ...string, +) { + r.Register(name, path, params, routerGenerator(), middlewares, methods...) +} + +func (r *routerRegistryImpl) Wrap(name string, path string, middlewares ...Middleware) RouterRegistry { + var ( + newName, newPath []string + ) + if len(name) > 0 { + newName = append(r.parentNames, name) + } else { + newName = r.parentNames + } + if len(path) > 0 { + newPath = append(r.parentPaths, path) + } else { + newPath = r.parentPaths + } + return &routerRegistryImpl{ + routerContextImpl: r.routerContextImpl, + r: r.r, + middlewares: append(r.middlewares, middlewares...), + name: strings.Join(newName, "."), + parentNames: newName, + parentPaths: newPath, + } +} +func (r *routerRegistryImpl) String() string { return r.urlFor().String() } diff --git a/routing/reverse_route.go b/routing/reverse_route.go new file mode 100644 index 0000000..c051a71 --- /dev/null +++ b/routing/reverse_route.go @@ -0,0 +1,247 @@ +package routing + +import ( + "errors" + "path" + "sort" + "strconv" + "strings" +) + +type ( + + // UrlResolver is a URL resolver utility that stores the handler information registered by UrlFor. + UrlResolver interface { + // Get is a resolver function that takes a name and parameters and returns a URL. If the URL is not found, it panics. + Get(urlName string, params ...string) string + // Reverse is a resolver function that takes a name and parameters and returns a URL. + Reverse(urlName string, params ...string) (string, error) + // ReverseWithParams is a resolver function that takes a name and parameters and returns a URL. + ReverseWithParams(urlName string, params []string) (string, error) + // MustReverse is a resolver function that takes a name and parameters and returns a URL. If the URL is not found, it panics. + MustReverse(urlName string, params ...string) string + // MustReverseWithParams is a resolver function that takes a name and parameters and returns a URL. If the URL is not found, it panics. + MustReverseWithParams(urlName string, params []string) string + } + + // UrlFor is a reverse-routing utility that stores the handler information. + UrlFor interface { + UrlResolver + // Add registers name, parameter, and URL for UrlResolver. + // If a duplicate name exists, an error is returned instead of registering. + Add(urlName, urlAddr string, params ...string) (string, error) + // MustAdd registers name, parameter, and URL for UrlResolver. + // If a duplicate name exists, it panics. + MustAdd(urlName, urlAddr string, params ...string) string + // AddGr registers name, parameter, and URL for UrlResolver with nested group infos. + // If a duplicate name exists, an error is returned instead of registering. + AddGr(urlName, urlAddr string, groupNames, groupAddrs []string, params ...string) (string, error) + // MustAddGr registers name, parameter, and URL for UrlResolver with nested group infos. + // If a duplicate name exists, it panics. + MustAddGr(urlName, urlAddr string, groupNames, groupAddrs []string, params ...string) string + // Clear clears all registered reverse-routing infos. + Clear() + // String returns summarized info of registered reverse-routing infos. + String() string + // ToResolver returns a UrlResolver. + ToResolver() UrlResolver + } + routerFragment struct { + url string + params []string + } + + reverseRouter map[string]routerFragment + + reverseRouteResolver func(urlName string, params []string) (string, error) +) + +// NewUrlFor returns a UrlFor. +func NewUrlFor() UrlFor { + router := make(reverseRouter) + return &router +} + +func (rr reverseRouteResolver) Get(urlName string, params ...string) string { + res, err := rr(urlName, params) + if err != nil { + panic(err) + } + return res +} + +func (rr reverseRouteResolver) Reverse(urlName string, params ...string) (string, error) { + return rr(urlName, params) +} + +func (rr reverseRouteResolver) ReverseWithParams(urlName string, params []string) (string, error) { + return rr(urlName, params) +} + +func (rr reverseRouteResolver) MustReverse(urlName string, params ...string) string { + res, err := rr(urlName, params) + if err != nil { + panic(err) + } + return res +} + +func (rr reverseRouteResolver) MustReverseWithParams(urlName string, params []string) string { + res, err := rr(urlName, params) + if err != nil { + panic(err) + } + return res +} + +func (us *reverseRouter) ToResolver() UrlResolver { + return reverseRouteResolver(us.ReverseWithParams) +} + +func (us *reverseRouter) MustReverse(urlName string, params ...string) string { + res, err := us.ReverseWithParams(urlName, params) + if err != nil { + panic(err) + } + return res +} + +func (us *reverseRouter) MustReverseWithParams(urlName string, params []string) string { + res, err := us.ReverseWithParams(urlName, params) + if err != nil { + panic(err) + } + return res +} + +func (us *reverseRouter) MustAdd(urlName, urlAddr string, params ...string) string { + addr, err := us.addInternal(urlName, urlAddr, nil, nil, params) + if err != nil { + panic(err) + } + return addr +} + +func (us *reverseRouter) Add(urlName, urlAddr string, params ...string) (string, error) { + return us.addInternal(urlName, urlAddr, nil, nil, params) +} + +func (us *reverseRouter) MustAddGr(urlName, urlAddr string, groupNames, groupAddrs []string, params ...string) string { + addr, err := us.addInternal(urlName, urlAddr, groupNames, groupAddrs, params) + if err != nil { + panic(err) + } + return addr +} + +func (us *reverseRouter) AddGr(urlName, urlAddr string, groupNames, groupAddrs []string, params ...string) (string, error) { + return us.addInternal(urlName, urlAddr, groupNames, groupAddrs, params) +} + +func (us *reverseRouter) Reverse(urlName string, params ...string) (string, error) { + return us.ReverseWithParams(urlName, params) +} + +func (us reverseRouter) addInternal(urlName, urlAddr string, groupNames, groupAddrs, params []string) (string, error) { + if _, ok := us[urlName]; ok { + return "", errors.New("Url already exists. Try to use .Get() method.") + } + routeName := strings.Join(append(groupNames, urlName), ".") + addr := path.Join(append(groupAddrs, urlAddr)...) + + tmpUrl := routerFragment{addr, params} + us[routeName] = tmpUrl + return addr, nil +} + +func (us reverseRouter) Clear() { + for k := range us { + delete(us, k) + } +} + +func (us *reverseRouter) Get(urlName string, params ...string) string { + url, err := us.ReverseWithParams(urlName, params) + if err != nil { + panic(err) + } + return url +} + +func (us reverseRouter) ReverseWithParams(urlName string, params []string) (string, error) { + if len(params) != len(us[urlName].params) { + return "", errors.New("Bad Url Reverse: mismatch params for URL: " + urlName) + } + res := us[urlName].url + for i, val := range params { + res = strings.Replace(res, us[urlName].params[i], val, 1) + } + return res, nil +} + +func (us reverseRouter) String() (ret string) { + var ( + numOfRoutes = len(us) + builder strings.Builder + needSort bool + ) + defer func() { + ret = builder.String() + }() + + builder.WriteString(strconv.FormatInt(int64(numOfRoutes), 10)) + builder.WriteByte(' ') + switch numOfRoutes { + case 0: + builder.WriteString("route") + return + case 1: + builder.WriteString("route:\n") + default: + builder.WriteString("routes:\n") + needSort = true + } + + fragmentStringer := func(builder *strings.Builder, idx int, key string, value routerFragment) { + builder.WriteByte('\t') + builder.WriteString(key) + builder.WriteByte('(') + builder.WriteString(value.url) + builder.WriteString(") [") + numOfParams := len(value.params) + for pathIdx := 0; pathIdx < numOfParams; pathIdx++ { + builder.WriteString(value.params[pathIdx]) + if pathIdx < numOfParams-1 { + builder.WriteByte(',') + } + } + builder.WriteByte(']') + if idx < numOfRoutes-1 { + builder.WriteByte('\n') + } + } + if needSort { + routeKeys := make([]string, 0, numOfRoutes) + for key := range us { + routeKeys = append(routeKeys, key) + } + + sort.SliceStable(routeKeys, func(i, j int) bool { + return len(us[routeKeys[i]].url) < len(us[routeKeys[j]].url) + }) + for i, key := range routeKeys { + fragmentStringer(&builder, i, key, us[key]) + } + } else { + i := 0 + for key, value := range us { + fragmentStringer(&builder, i, key, value) + i++ + } + } + return +} + +func (us reverseRouter) getParamName(urlName string, num int) string { + return us[urlName].params[num] +} diff --git a/serve/fasthttp.go b/serve/fasthttp.go new file mode 100644 index 0000000..5ad297a --- /dev/null +++ b/serve/fasthttp.go @@ -0,0 +1,210 @@ +package serve + +import ( + "amuz.es/src/go/eighty" + "amuz.es/src/go/misc/strutil" + "bytes" + "crypto/subtle" + "github.com/valyala/fasthttp" + "net/http" + "net/textproto" + "time" +) + +const sniffLen = 512 + +var ( + etagPrefix = []byte("W/") + + unixEpochTime = time.Unix(0, 0) +) + +type condResult int + +const ( + condNone condResult = iota + condTrue + condFalse +) + +// checkPreconditions evaluates request preconditions and reports whether a precondition +// resulted in sending StatusNotModified or StatusPreconditionFailed. +func checkPreconditions(r *fasthttp.Request, w *fasthttp.Response, modtime time.Time, etag []byte) (done bool) { + // This function carefully follows RFC 7232 section 6. + ch := checkIfMatch(r, w) + if ch == condNone { + ch = checkIfUnmodifiedSince(r, modtime) + } + if ch == condFalse { + w.SetStatusCode(http.StatusPreconditionFailed) + return true + } + switch checkIfNoneMatch(etag, r, w) { + case condFalse: + if subtle.ConstantTimeCompare(r.Header.Method(), eighty.MethodGET) == 1 || subtle.ConstantTimeCompare(r.Header.Method(), eighty.MethodHEAD) == 1 { + writeNotModified(w) + return true + } else { + w.SetStatusCode(http.StatusPreconditionFailed) + return true + } + case condNone: + if checkIfModifiedSince(r, modtime) == condFalse { + writeNotModified(w) + return true + } + } + return +} + +func checkIfNoneMatch(providedEtag []byte, r *fasthttp.Request, w *fasthttp.Response) condResult { + inm := r.Header.Peek("If-None-Match") + if len(inm) == 0 { + return condNone + } + buf := inm + for { + buf = textproto.TrimBytes(buf) + if len(buf) == 0 { + break + } + if buf[0] == ',' { + buf = buf[1:] + } + if buf[0] == '*' { + return condFalse + } + etag, remain := scanETag(buf) + if len(etag) == 0 { + break + } + if etagWeakMatch(etag, providedEtag) { + return condFalse + } + buf = remain + } + return condTrue +} + +func checkIfModifiedSince(r *fasthttp.Request, modtime time.Time) condResult { + if subtle.ConstantTimeCompare(r.Header.Method(), eighty.MethodGET) != 1 && subtle.ConstantTimeCompare(r.Header.Method(), eighty.MethodHEAD) != 1 { + return condNone + } + ims := r.Header.Peek("If-Modified-Since") + if len(ims) == 0 || isZeroTime(modtime) { + return condNone + } + t, err := http.ParseTime(strutil.B2S(ims)) + if err != nil { + return condNone + } + // The Date-Modified header truncates sub-second precision, so + // use mtime < t+1s instead of mtime <= t to check for unmodified. + if modtime.Before(t.Add(1 * time.Second)) { + return condFalse + } + return condTrue +} + +func checkIfUnmodifiedSince(r *fasthttp.Request, modtime time.Time) condResult { + ius := r.Header.Peek("If-Unmodified-Since") + if len(ius) == 0 || isZeroTime(modtime) { + return condNone + } + if t, err := http.ParseTime(strutil.B2S(ius)); err == nil { + // The Date-Modified header truncates sub-second precision, so + // use mtime < t+1s instead of mtime <= t to check for unmodified. + if modtime.Before(t.Add(1 * time.Second)) { + return condTrue + } + return condFalse + } + return condNone +} + +func checkIfMatch(r *fasthttp.Request, w *fasthttp.Response) condResult { + im := r.Header.Peek("If-Match") + if len(im) == 0 { + return condNone + } + for { + im = textproto.TrimBytes(im) + if len(im) == 0 { + break + } + if im[0] == ',' { + im = im[1:] + continue + } + if im[0] == '*' { + return condTrue + } + etag, remain := scanETag(im) + if len(etag) == 0 { + break + } + if etagStrongMatch(etag, w.Header.Peek("Etag")) { + return condTrue + } + im = remain + } + + return condFalse +} + +func scanETag(s []byte) (etag []byte, remain []byte) { + s = textproto.TrimBytes(s) + start := 0 + if bytes.HasPrefix(s, etagPrefix) { + start = 2 + } + if len(s[start:]) < 2 || s[start] != '"' { + return + } + // ETag is either W/"text" or "text". + // See RFC 7232 2.3. + for i := start + 1; i < len(s); i++ { + c := s[i] + switch { + // Character values allowed in ETags. + case c == 0x21 || c >= 0x23 && c <= 0x7E || c >= 0x80: + case c == '"': + return s[:i+1], s[i+1:] + default: + return + } + } + return +} + +func etagStrongMatch(a, b []byte) bool { + return subtle.ConstantTimeCompare(a, b) == 1 && len(a) > 0 && a[0] == '"' +} + +// etagWeakMatch reports whether a and b match using weak ETag comparison. +// Assumes a and b are valid ETags. +func etagWeakMatch(a, b []byte) bool { + return subtle.ConstantTimeCompare( + bytes.TrimPrefix(a, etagPrefix), + bytes.TrimPrefix(b, etagPrefix), + ) == 1 +} + +func writeNotModified(w *fasthttp.Response) { + // RFC 7232 section 4.1: + // a sender SHOULD NOT generate representation metadata other than the + // above listed fields unless said metadata exists for the purpose of + // guiding cache updates (e.g., Last-Modified might be useful if the + // response does not have an ETag field). + w.Header.Del("Content-Type") + w.Header.Del("Content-Length") + if len(w.Header.Peek("Etag")) > 0 { + w.Header.Del("Last-Modified") + } + w.SetStatusCode(http.StatusNotModified) + w.SkipBody = true +} + +func isZeroTime(t time.Time) bool { + return t.IsZero() || t.Equal(unixEpochTime) +} diff --git a/serve/iface.go b/serve/iface.go new file mode 100644 index 0000000..39b373b --- /dev/null +++ b/serve/iface.go @@ -0,0 +1,27 @@ +package serve + +import ( + "io" + "os" +) + +type ( + // File defines a abstract file object,which provides caching hash. + File interface { + io.Closer + io.Reader + io.Seeker + // Name returns abstract path of file. + Name() string + + // Stat returns a os.FileInfo describing the named file. + // If there is an error, it will be of type *os.PathError. + Stat() (os.FileInfo, error) + + // Hash returns a file content checksum that is used for the e-tags. + Hash() []byte + } + + // FileProvider is an alias of the file resolver function. + FileProvider = func(name string) (File, error) +) diff --git a/serve/index.go b/serve/index.go new file mode 100644 index 0000000..8f436f9 --- /dev/null +++ b/serve/index.go @@ -0,0 +1,2 @@ +// Package serve provides static file serving functions. +package serve diff --git a/serve/link.go b/serve/link.go new file mode 100644 index 0000000..1c93000 --- /dev/null +++ b/serve/link.go @@ -0,0 +1,40 @@ +package serve + +import ( + "io" + "net/http" + "time" + _ "unsafe" +) + +//go:linkname toHTTPError net/http.toHTTPError +//go:nosplit +func toHTTPError(error) (string, int) + +//go:linkname serveContent net/http.serveContent +//go:nosplit +func serveContent( + w http.ResponseWriter, + r *http.Request, + name string, + modtime time.Time, + sizeFunc func() (int64, error), + content io.ReadSeeker, +) + +// ServeFile handles static file response. +func ServeFile(w http.ResponseWriter, r *http.Request, f http.File) { + stat, err := f.Stat() + if err != nil { + msg, code := toHTTPError(err) + http.Error(w, msg, code) + return + } else if stat.IsDir() { + http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) + return + } + + // serveContent will check modification time + sizeFunc := func() (int64, error) { return stat.Size(), nil } + serveContent(w, r, stat.Name(), stat.ModTime(), sizeFunc, f) +} diff --git a/serve/static_fasthttp.go b/serve/static_fasthttp.go new file mode 100644 index 0000000..ac596b1 --- /dev/null +++ b/serve/static_fasthttp.go @@ -0,0 +1,153 @@ +package serve + +import ( + "amuz.es/src/go/eighty" + "amuz.es/src/go/misc/strutil" + "bytes" + "crypto/subtle" + "github.com/valyala/fasthttp" + "io" + "log" + "mime" + "net/http" + "os" + "path" + "path/filepath" + "strconv" + "sync" + "time" +) + +const cacheControlPrefix = "public, max-age=" + +type staticFileHandler struct { + contentPool sync.Pool + pathPrefix []byte + baseVersion int64 + cacheFiller func(header *fasthttp.ResponseHeader, versionPrefix, discretePrefix []byte) + fileProvider FileProvider +} + +// NewStaticFileHandler returns a static handler for fasthttp with advanced caching and url prefix stripping. +func NewStaticFileHandler(debug bool, urlPrefix, staticUrl string, startupTime time.Time, fileProvider FileProvider) fasthttp.RequestHandler { + h := staticFileHandler{ + pathPrefix: []byte(path.Join(urlPrefix, staticUrl)), + baseVersion: startupTime.Unix(), + fileProvider: fileProvider, + } + + h.contentPool.New = func() any { return [sniffLen]byte{} } + + if debug { + h.cacheFiller = func(header *fasthttp.ResponseHeader, versionPrefix, discretePrefix []byte) {} + } else { + h.cacheFiller = func(header *fasthttp.ResponseHeader, versionPrefix, discretePrefix []byte) { + if receivedVersion, err := strconv.ParseInt(strutil.B2S(versionPrefix), 10, 64); err != nil || receivedVersion > 0 && receivedVersion < h.baseVersion { + return + } + + var cacheDuration int64 = 2592000 + if pushedDuration, err := strconv.ParseInt(strutil.B2S(discretePrefix), 10, 64); err == nil && pushedDuration > 0 && pushedDuration > cacheDuration { + cacheDuration = pushedDuration + } + //add header + header.Set(eighty.CacheControlHeader, cacheControlPrefix+strconv.FormatInt(cacheDuration, 10)) + header.Set(eighty.ExpiresHeader, strconv.FormatInt(cacheDuration, 10)) + } + } + return h.Handle +} + +func (h *staticFileHandler) resolvFile(pathData []byte) (f File, stat os.FileInfo) { + var err error + // normalizePath + requestPath := strutil.B2S(bytes.TrimPrefix(pathData, h.pathPrefix)) + strippedPath := path.Clean(requestPath) + if f, err = h.fileProvider(strippedPath); err != nil || f == nil { + panic(eighty.HandledErrorNotFound) + } + + // resolvFileInfo + if stat, err = f.Stat(); err != nil { + _, code := toHTTPError(err) + err, _ = eighty.HandledErrorCodeOf(code) + panic(err) + } else if stat.IsDir() { + panic(eighty.HandledErrorNotFound) + } + return +} + +func (h *staticFileHandler) resolvContentType(name string, header *fasthttp.ResponseHeader, file File) { + if ctypes := header.ContentType(); len(ctypes) > 0 { + return + } + + if ctype := mime.TypeByExtension(filepath.Ext(name)); len(ctype) > 0 { + header.SetContentType(ctype) + return + } + + buf := h.contentPool.Get().([sniffLen]byte) + defer h.contentPool.Put(buf) + + // read a chunk to decide between utf-8 text and binary + if n, _ := file.Read(buf[:]); n > 0 { + // rewind to output whole file + if _, seekErr := file.Seek(0, io.SeekStart); seekErr != nil { + log.Printf("file %s seeker can't seek", name) + panic(http.StatusInternalServerError) + } + header.SetContentType(http.DetectContentType(buf[:n])) + } else { + header.SetContentType("application/octet-stream") + } +} + +// Handle is a handler method for http request. +func (h *staticFileHandler) Handle(ctx *fasthttp.RequestCtx) { + + var ( + // step 1 stripping url + // step 2 openfile + f, stat = h.resolvFile(ctx.Path()) + queryArgs = ctx.QueryArgs() + modTime = stat.ModTime() + name = f.Name() + respHdr = &ctx.Response.Header + etag []byte + ) + // step 3 cache control + h.cacheFiller(respHdr, queryArgs.Peek("v"), queryArgs.Peek("d")) + respHdr.Set(eighty.VaryHeader, eighty.UserAgentHeader) + // step 4 serve file + if !isZeroTime(modTime) { + respHdr.SetLastModified(modTime) + } + + etag = f.Hash() + if len(etag) > 0 { + respHdr.SetBytesV(eighty.EtagHeader, etag) + } + + // serveContent will check modification time + if checkPreconditions(&ctx.Request, &ctx.Response, modTime, etag) { + // not modified + return + } + + // step 5 resolve content type + h.resolvContentType(name, respHdr, f) + + if size := stat.Size(); size <= 0 { + ctx.Response.SkipBody = true + ctx.SetStatusCode(http.StatusNoContent) + return + } else if subtle.ConstantTimeCompare(ctx.Method(), eighty.MethodHEAD) == 1 { + ctx.Response.SkipBody = true + ctx.SetStatusCode(http.StatusOK) + } else { + ctx.SetBodyStream(f, int(size)) + ctx.SetStatusCode(http.StatusOK) + } +} diff --git a/serve/stub.s b/serve/stub.s new file mode 100644 index 0000000..745eede --- /dev/null +++ b/serve/stub.s @@ -0,0 +1,4 @@ +// The runtime package uses //go:linkname to push a few functions into this +// package but we still need a .s file so the Go tool does not pass -complete +// to the go tool compile so the latter does not complain about Go functions +// with no bodies. \ No newline at end of file diff --git a/status_resp.go b/status_resp.go new file mode 100644 index 0000000..fa086a9 --- /dev/null +++ b/status_resp.go @@ -0,0 +1,263 @@ +package eighty + +import ( + "amuz.es/src/go/misc" + "github.com/valyala/fasthttp" + "log" +) + +type ( + // HandledError is a http status handler type + HandledError int + // PageRenderer is a function interface type for the http status page renderer. + PageRenderer = func(r *fasthttp.RequestCtx, name string, context map[string]any) error +) + +var ( + // Is the interface compatible with the actual dto + _ error = (*HandledError)(nil) +) + +// Collection of predefined HandledError. +const ( + // HandledErrorBadRequest : 400, BadRequest http status + HandledErrorBadRequest HandledError = 400 + // HandledErrorUnauthorized : 401, Unauthorized http status + HandledErrorUnauthorized HandledError = 401 + // HandledErrorForbidden : 403, Forbidden http status + HandledErrorForbidden HandledError = 403 + // HandledErrorNotFound : 404, NotFound http status + HandledErrorNotFound HandledError = 404 + // HandledErrorMethodNotAllowed : 405, MethodNotAllowed http status + HandledErrorMethodNotAllowed HandledError = 405 + // HandledErrorNotAcceptable : 406, NotAcceptable http status + HandledErrorNotAcceptable HandledError = 406 + // HandledErrorRequestTimeout : 408, RequestTimeout http status + HandledErrorRequestTimeout HandledError = 408 + // HandledErrorGone : 410, Gone http status + HandledErrorGone HandledError = 410 + // HandledErrorTooManyRequests : 429, TooManyRequests http status + HandledErrorTooManyRequests HandledError = 429 + // HandledErrorInternalServerError : 500, InternalServerError http status + HandledErrorInternalServerError HandledError = 500 + // HandledErrorNotImplemented : 501, NotImplemented http status + HandledErrorNotImplemented HandledError = 501 + // HandledErrorBadGateway : 502, BadGateway http status + HandledErrorBadGateway HandledError = 502 + // HandledErrorServiceUnavailable : 503, ServiceUnavailable http status + HandledErrorServiceUnavailable HandledError = 503 + // HandledErrorGatewayTimeout : 504, GatewayTimeout http status + HandledErrorGatewayTimeout HandledError = 504 +) + +// HandledErrorCodeOf is the conversion function with the http status code to HandledError. +func HandledErrorCodeOf(value int) (HandledError, bool) { + switch value { + case int(HandledErrorBadRequest): + return HandledErrorBadRequest, true + case int(HandledErrorUnauthorized): + return HandledErrorUnauthorized, true + case int(HandledErrorForbidden): + return HandledErrorForbidden, true + case int(HandledErrorNotFound): + return HandledErrorNotFound, true + case int(HandledErrorMethodNotAllowed): + return HandledErrorMethodNotAllowed, true + case int(HandledErrorNotAcceptable): + return HandledErrorNotAcceptable, true + case int(HandledErrorGone): + return HandledErrorGone, true + case int(HandledErrorNotImplemented): + return HandledErrorNotImplemented, true + case int(HandledErrorBadGateway): + return HandledErrorBadGateway, true + case int(HandledErrorServiceUnavailable): + return HandledErrorServiceUnavailable, true + case int(HandledErrorGatewayTimeout): + return HandledErrorGatewayTimeout, true + case int(HandledErrorInternalServerError): + return HandledErrorInternalServerError, true + default: + return HandledErrorInternalServerError, false + } +} + +// HandledErrorOf is the conversion function with the generic error object to HandledError. +func HandledErrorOf(value any) (HandledError, bool) { + switch value { + case HandledErrorBadRequest: + return HandledErrorBadRequest, true + case HandledErrorUnauthorized: + return HandledErrorUnauthorized, true + case HandledErrorForbidden: + return HandledErrorForbidden, true + case HandledErrorNotFound: + return HandledErrorNotFound, true + case HandledErrorMethodNotAllowed: + return HandledErrorMethodNotAllowed, true + case HandledErrorNotAcceptable: + return HandledErrorNotAcceptable, true + case HandledErrorRequestTimeout: + return HandledErrorRequestTimeout, true + case HandledErrorGone: + return HandledErrorGone, true + case HandledErrorTooManyRequests: + return HandledErrorTooManyRequests, true + case HandledErrorNotImplemented: + return HandledErrorNotImplemented, true + case HandledErrorBadGateway: + return HandledErrorBadGateway, true + case HandledErrorServiceUnavailable: + return HandledErrorServiceUnavailable, true + case HandledErrorGatewayTimeout: + return HandledErrorGatewayTimeout, true + case HandledErrorInternalServerError: + return HandledErrorInternalServerError, true + default: + return HandledErrorInternalServerError, false + } +} + +// StatusCode returns the http status code. +func (handler HandledError) StatusCode() int { + return int(handler) +} + +// StatusMessage returns the http status message. +func (handler HandledError) StatusMessage() (msg string) { + switch handler { + case HandledErrorBadRequest: + msg = "Bad Request" + case HandledErrorUnauthorized: + msg = "Unauthorized" + case HandledErrorForbidden: + msg = "Forbidden" + case HandledErrorNotFound: + msg = "Not Found" + case HandledErrorMethodNotAllowed: + msg = "Method Not Allowed" + case HandledErrorNotAcceptable: + msg = "Not Acceptable" + case HandledErrorRequestTimeout: + msg = "Request Timeout" + case HandledErrorGone: + msg = "Gone" + case HandledErrorTooManyRequests: + msg = "Too Many Requests" + case HandledErrorNotImplemented: + msg = "Not Implemented" + case HandledErrorBadGateway: + msg = "Bad Gateway" + case HandledErrorServiceUnavailable: + msg = "Service Unavailable" + case HandledErrorGatewayTimeout: + msg = "Gateway Timeout" + case HandledErrorInternalServerError: + fallthrough + default: + msg = "Internal Server Error" + } + return +} + +// StatusDescription returns the http status description. +func (handler HandledError) StatusDescription() (msg string) { + switch handler { + case HandledErrorBadRequest: + msg = "The request could not be understood by the server due to malformed syntax." + case HandledErrorUnauthorized: + msg = "The request requires user authentication." + case HandledErrorForbidden: + msg = "The server understood the request, but is refusing to fulfill it." + case HandledErrorNotFound: + msg = "The server has not found anything matching the Request-URI." + case HandledErrorMethodNotAllowed: + msg = "The method specified in the Request-Line is not allowed for the resource identified by the Request-URI." + case HandledErrorNotAcceptable: + msg = "The resource identified by the request is only capable of generating response entities which have content characteristics not acceptable according to the accept headers sent in the request." + case HandledErrorRequestTimeout: + msg = "The client did not produce a request within the time that the server was prepared to wait." + case HandledErrorGone: + msg = "The requested resource is no longer available at the server and no forwarding address is known." + case HandledErrorTooManyRequests: + msg = "The client has sent too many requests in a given amount of time." + case HandledErrorNotImplemented: + msg = "The server does not support the functionality required to fulfill the request." + case HandledErrorBadGateway: + msg = "The server, while acting as a gateway or proxy, received an invalid response from the upstream server it accessed in attempting to fulfill the request." + case HandledErrorServiceUnavailable: + msg = "The server is currently unable to handle the request due to a temporary overloading or maintenance of the server." + case HandledErrorGatewayTimeout: + msg = "The server, while acting as a gateway or proxy, did not receive a timely response from the upstream server specified by the URI." + case HandledErrorInternalServerError: + fallthrough + default: + msg = "The server encountered an unexpected condition which prevented it from fulfilling the request." + } + return +} + +// RenderPage is a html page renderer function, that follows the http status code with context. +func (handler HandledError) RenderPage(ctx *fasthttp.RequestCtx, templateRenderer PageRenderer, err error) { + defer func() { + if len(ctx.Response.Header.ContentType()) == 0 { + ctx.SetContentType(HtmlContentUTF8Type[0]) + } + ctx.SetStatusCode(handler.StatusCode()) + }() + + tmplCtx := map[string]any{ + "title": handler.StatusMessage(), + "description": handler.StatusDescription(), + "nofollow": true, + } + if err != nil { + tmplCtx["message"] = err.Error() + } + + if err := templateRenderer(ctx, "error", tmplCtx); err != nil { + log.Print("cannot render error page: ", err) + } + return +} + +// RenderAPI is a json renderer function, that follows the http status code with context. +func (handler HandledError) RenderAPI(ctx *fasthttp.RequestCtx, _ error) { + defer func() { + if len(ctx.Response.Header.ContentType()) == 0 { + ctx.SetContentType(JsonContentType[0]) + } + ctx.SetStatusCode(handler.StatusCode()) + }() + stream := misc.JSONCodec.BorrowStream(ctx) + defer misc.JSONCodec.ReturnStream(stream) + stream.WriteObjectStart() + stream.WriteObjectField("code") + stream.WriteInt(handler.StatusCode()) + stream.WriteMore() + stream.WriteObjectField("message") + stream.WriteString(handler.StatusMessage()) + stream.WriteObjectEnd() + _ = stream.Flush() + return +} + +// Error implements the built-in interface type error. +func (handler HandledError) Error() string { + return handler.StatusMessage() +} + +// WrapHandledError is the panic handler function with a thrown panic object. +func WrapHandledError(panicObj any) (handler HandledError, err error) { + var panicObjIsErr bool + if err, panicObjIsErr = panicObj.(error); panicObjIsErr { + var errIsDefined bool + if handler, errIsDefined = HandledErrorOf(err); errIsDefined { + err = nil + } + } else { + log.Printf("panic object(%v) isn't error interface", panicObj) + handler = HandledErrorInternalServerError + } + return +}