From 27f418354408584ed62266682d3e0a12ad210a28 Mon Sep 17 00:00:00 2001 From: eudore Date: Thu, 31 Aug 2023 19:17:03 +0800 Subject: [PATCH] up acion --- .../workflows/{go.yml => github-action.yml} | 159 +- .github/workflows/golang-lint.yml | 59 + CHANGELOG.md | 50 +- README.md | 4 +- _example/README.md | 30 +- _example/appCommand.go | 22 +- _example/appDaemon.go | 25 +- _example/appExtend.go | 3 +- _example/appHealth.go | 37 + _example/appNew.go | 2 + _example/appNew2.go | 116 ++ _example/appNotify.go | 38 - _example/appReload.go | 71 +- _example/appRestart.go | 60 + _example/appStatic.go | 29 +- _example/app_test.go | 8 +- _example/benchBuffer_test.go | 75 + _example/benchFuncRun_test.go | 87 ++ _example/benchFunc_test.go | 6 +- _example/benchHeader_test.go | 1 + _example/benchName_test.go | 51 + _example/client_test.go | 357 +++-- _example/configStd.go | 38 +- _example/config_test.go | 74 +- _example/contextPush.go | 55 +- _example/context_test.go | 186 ++- _example/controllerAutoRoute.go | 13 - _example/controller_test.go | 17 + _example/converter_test.go | 359 ----- _example/database_test.go | 69 + _example/funccreator_test.go | 189 +++ _example/handler_test.go | 119 +- _example/handlerdata_test.go | 385 ++--- _example/logger_test.go | 602 ++++---- _example/middleware2_test.go | 701 +++++++++ _example/middleware3_test.go | 125 ++ _example/middleware_test.go | 974 ++---------- _example/otherNotify.go | 219 +++ _example/policy_test.go | 79 +- _example/protobuf_test.go | 9 +- _example/routerDelete.go | 26 +- _example/routerStd.go | 45 +- _example/router_test.go | 52 +- _example/routerstd_test.go | 22 +- _example/serverConfig.go | 16 - _example/server_test.go | 23 +- _example/util16_test.go | 29 - _example/util_test.go | 382 +++-- app.go | 118 +- client.go | 954 +++++------- clientoption.go | 573 ++++++++ config.go | 489 ++++-- const.go | 782 ++++++---- context.go | 662 +++++---- controller.go | 156 +- converter.go | 872 ----------- converter2.go | 190 --- daemon/command.go | 232 +++ daemon/daemon.go | 111 ++ daemon/restart.go | 114 ++ daemon/signal.go | 70 + database.go | 29 +- funccreator.go | 653 +++++++++ funcdefine.go | 562 +++++++ go.mod | 2 +- handler.go | 816 +++------- handlerdata.go | 759 ++-------- handlerdata2.go | 328 +++++ handlerextender.go | 734 +++++++++ logger.go | 1306 +++++------------ loggerformatter.go | 627 ++++++++ loggerhandler.go | 472 ++++++ middleware/README.md | 23 +- middleware/all.go | 109 +- middleware/black.go | 21 +- middleware/breaker.go | 14 +- middleware/cache.go | 138 +- middleware/compress.go | 276 ++-- middleware/const.go | 66 + middleware/cors.go | 48 +- middleware/csrf.go | 56 +- middleware/doc.go | 122 +- middleware/look.go | 99 +- middleware/nethttp.go | 35 +- middleware/pprof.go | 110 +- middleware/rate.go | 51 +- middleware/referer.go | 70 +- middleware/rewrite.go | 4 +- middleware/router.go | 8 +- policy/pbac.go | 53 +- policy/policy.go | 60 +- policy/util.go | 46 +- router.go | 541 ++++--- routerstd.go | 102 +- server.go | 144 +- util.go | 795 +++++----- util_16.go | 66 - value.go | 701 +++++++++ 98 files changed, 12865 insertions(+), 8602 deletions(-) rename .github/workflows/{go.yml => github-action.yml} (76%) create mode 100644 .github/workflows/golang-lint.yml create mode 100644 _example/appHealth.go create mode 100644 _example/appNew2.go delete mode 100644 _example/appNotify.go create mode 100644 _example/appRestart.go create mode 100644 _example/benchBuffer_test.go create mode 100644 _example/benchFuncRun_test.go create mode 100644 _example/benchName_test.go delete mode 100644 _example/converter_test.go create mode 100644 _example/database_test.go create mode 100644 _example/funccreator_test.go create mode 100644 _example/middleware2_test.go create mode 100644 _example/middleware3_test.go create mode 100644 _example/otherNotify.go delete mode 100644 _example/serverConfig.go delete mode 100644 _example/util16_test.go create mode 100644 clientoption.go delete mode 100644 converter.go delete mode 100644 converter2.go create mode 100644 daemon/command.go create mode 100644 daemon/daemon.go create mode 100644 daemon/restart.go create mode 100644 daemon/signal.go create mode 100644 funccreator.go create mode 100644 funcdefine.go create mode 100644 handlerdata2.go create mode 100644 handlerextender.go create mode 100644 loggerformatter.go create mode 100644 loggerhandler.go create mode 100644 middleware/const.go delete mode 100644 util_16.go create mode 100644 value.go diff --git a/.github/workflows/go.yml b/.github/workflows/github-action.yml similarity index 76% rename from .github/workflows/go.yml rename to .github/workflows/github-action.yml index ac4b343..5b53f7a 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/github-action.yml @@ -1,83 +1,76 @@ -name: Run Tests -on: - push: - branches: ["master"] - pull_request: - branches: ["master"] -permissions: - contents: read -jobs: - lint: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-go@v3 - with: - go-version-file: go.mod - - uses: golangci/golangci-lint-action@v3 - continue-on-error: true - with: - args: --verbose --disable errcheck - test: - needs: lint - strategy: - max-parallel: 2 - matrix: - os: [macos, ubuntu] - go: ["1.18"] - include: - - os: macos - gopath: /Users/runner/go - gocache: /Users/runner/Library/Caches/go-build - - os: ubuntu - gopath: /home/runner/go - gocache: /home/runner/.cache/go-build - name: ${{ matrix.os }} @ Go ${{ matrix.go }} - runs-on: ${{ matrix.os }}-latest - env: - GO111MODULE: off - GOPATH: ${{ matrix.gopath }}:${{ github.workspace }} - CGO_ENABLED: 1 - WORKDIR: src/github.com/eudore/eudore/ - PACKAGES: github.com/eudore/eudore,github.com/eudore/eudore/middleware - defaults: - run: - working-directory: ${{ env.WORKDIR }} - steps: - - uses: actions/checkout@v3 - with: - path: ${{ env.WORKDIR }} - - uses: actions/setup-go@v3 - with: - go-version: ${{ matrix.go }} - - uses: actions/cache@v3 - with: - key: ${{ runner.os }}-go-${{ hashFiles('src/github.com/eudore/eudore/go.mod') }} - restore-keys: ${{ runner.os }}-go- - path: | - ${{ matrix.gopath }} - ${{ matrix.gocache }} - - name: Run Debug - run: rm -f _example/xxxlogger_test.go - - name: Run Go Get - run: for pkg in $(go list -json _example/*_test.go | jq -r '.XTestImports[]' | grep -E "github|monkey" | grep -v eudore); do go get -v $pkg; done - - name: Run Tests - run: set -o pipefail;go test -v -timeout=1m -race -cover -coverprofile=coverage.out -coverpkg='${{ env.PACKAGES }}' _example/*_test.go | tee output; - - name: Run Notice - if: ${{ strategy.job-index == 0 }} - run: 'echo "::notice::$(tail -3 output | grep "coverage: ")"' - - name: Run Coverage - if: ${{ strategy.job-index == 0 }} - run: go tool cover -html coverage.out -o coverage.html - - uses: actions/upload-artifact@v3 - if: ${{ strategy.job-index == 0 }} - with: - name: Coverage-eudore-${{ github.ref_name }}-${{ matrix.os }}-${{ matrix.go }} - path: ${{ env.WORKDIR }}/coverage.html - - uses: codecov/codecov-action@v3 - if: ${{ strategy.job-index == 0xf }} - with: - token: ${{ secrets.CODECOV_TOKEN }} - files: ${{ env.WORKDIR }}/coverage.out - flags: ${{ matrix.os }},go-${{ matrix.go }} - verbose: true +name: Run Tests +on: + push: + branches: ["master"] + pull_request: + branches: ["master"] +permissions: + contents: read +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-go@v3 + with: + go-version-file: go.mod + - uses: golangci/golangci-lint-action@v3 + continue-on-error: true + with: + args: --verbose -c .github/workflows/golang-lint.yml + test: + strategy: + max-parallel: 2 + matrix: + os: [macos, ubuntu] + go: ["1.18"] + include: + - os: macos + gopath: /Users/runner/go + gocache: /Users/runner/Library/Caches/go-build + - os: ubuntu + gopath: /home/runner/go + gocache: /home/runner/.cache/go-build + name: ${{ matrix.os }} @ Go ${{ matrix.go }} + runs-on: ${{ matrix.os }}-latest + env: + GO111MODULE: off + GOPATH: ${{ matrix.gopath }}:${{ github.workspace }} + CGO_ENABLED: 1 + WORKDIR: src/github.com/eudore/eudore/ + PACKAGES: github.com/eudore/eudore,github.com/eudore/eudore/middleware + defaults: + run: + working-directory: ${{ env.WORKDIR }} + steps: + - uses: actions/checkout@v3 + with: + path: ${{ env.WORKDIR }} + - uses: actions/setup-go@v3 + with: + go-version: ${{ matrix.go }} + - uses: actions/cache@v3 + with: + key: ${{ runner.os }}-go-${{ hashFiles('src/github.com/eudore/eudore/go.mod') }} + restore-keys: ${{ runner.os }}-go- + path: | + ${{ matrix.gopath }} + ${{ matrix.gocache }} + - name: Run Debug + run: rm -f _example/logger_test.go + - name: Run Go Get + run: for pkg in $(go list -json _example/*_test.go | jq -r '.XTestImports[]' | grep -E "github|monkey" | grep -v eudore); do go get -v $pkg; done + - name: Run Tests + run: go test -v -timeout=1m -race -cover -coverprofile=coverage.out -coverpkg='${{ env.PACKAGES }}' _example/*_test.go | tee output + - name: Run Notice + if: ${{ strategy.job-index == 0 }} + run: | + 'echo "::notice::$(tail -3 output | grep "coverage: ")"' + - name: Run Coverage + if: ${{ strategy.job-index == 0 }} + run: go tool cover -html coverage.out -o coverage.html + - uses: actions/upload-artifact@v3 + if: ${{ strategy.job-index == 0 }} + with: + name: Coverage-eudore-${{ github.ref_name }}-${{ matrix.os }}-${{ matrix.go }} + path: ${{ env.WORKDIR }}/coverage.html \ No newline at end of file diff --git a/.github/workflows/golang-lint.yml b/.github/workflows/golang-lint.yml new file mode 100644 index 0000000..d18f713 --- /dev/null +++ b/.github/workflows/golang-lint.yml @@ -0,0 +1,59 @@ +run: + timeout: 2m + skip-dirs: + - database + - database2 + skip-files: + - protobuf.go + - middleware/admin.go + - middleware/dump.go +issues: + max-same-issues: 10 +linters: + enable-all: true + disable: + - golint + - deadcode + - nosnakecase + - ifshort + - scopelint + - maligned + - structcheck + - exhaustivestruct + - varcheck + - interfacer + - errcheck + - varnamelen + - wrapcheck + - nlreturn + - ireturn + - interfacebloat + - gochecknoglobals + - nonamedreturns + - forcetypeassert + - exhaustruct + - lll + - gosec + - wsl + - containedctx + - exhaustive + - contextcheck + - tagliatelle + - gomnd + - goerr113 + - nestif + - gocognit + - tagalign +linters-settings: + gocyclo: + min-complexity: 15 + cyclop: + max-complexity: 15 + depguard: + rules: + main: + allow: + - $gostd + - github.com/eudore/eudore + funlen: + statements: 50 \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 502d978..d2bb77c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,24 +1,44 @@ # Change Log -setIntField set Duration type -middleware/cache 不同accept会导致数据格式不同 - Next - Database 实现 -[2022年10月31日] -- App 优化运行输出日志 -- Client 完整重构 -- Logger 修复Sync方法,更新其他组件日志 -- Config 合并实现方法 -- Render/BindProtobuf 无需proto文件进行编码 -- NewContextMessage 新增函数返回请求上下文消息,复用message。 -- middleware/gzip 使用自定义压缩函数,可以使用br压缩。 -- middleware/look 使用自定义data获取函数。 +[2023年8月31日] +- go.mod go版本依赖从1.9升级为1.18,增加error embed 泛型等新版本特性支持。 +- github 使用action配置添加lint和codecov。 +- LoggerStd 修改为Hook结构增强扩展。 +- Client 添加ClietOption/ClintBody,修改请求构造方法。 +- RouterStd 使用Group时参数loggerkind时修改router日志输出级别,加入Metadata接口实现。 +- HandlerData validate使用新fc实现避免反射,完成filter实现过滤或修改数据。 +- FuncCreator 使用泛型重构减少反射使用,额外扩展新函数规则,允许使用逻辑关系式。 +- Context FormValues调用parseForm解析方法修改,不将PostForm和Form复制数据。 +- ConvertTo 移除To/ToMap等转换函数,Get/Set函数优化异常处理。 +- GetAny 修改GetAny相关函数使用泛型实现,重命名移除多余函数。 +- HandlerExtender 默认扩展函数重命名。 +- ResponseWriter 添加WriteString和Unwrap实现。 +- NewFileSystems 处理Dir和Embed的http混合文件对象。 +- NewConfigParseEnvFile 配置解析env文件。 +- NewConfigParseArgs 保存未处理的命令行参数。 +- LoggerStdDataJSON 具有环境变量EnvEudoreDaemonEnable时禁用标准输出。 +- ServerListenConfig 使用DefaultServerListen启动监听。 +- middleware/cache 添加对Accept/Accept-Encoding/304支持。 +- middleware/compress 添加选择压缩方法,忽略小Body和已压缩Mime。 +- middleware/bodylimit 忽略NoBody,使用http.MaxBytesReader限制body长度。 +- daemon 整理启动命令、后台启动、信号处理、热重启,不进行单位测试覆盖。 + +[2022年10月31日](https://github.com/eudore/eudore/tree/de9fd1ea1b653ba6e4f9bb5c108733e3142cadf6) +- App 优化运行输出日志 +- Client 完整重构 +- Logger 修复Sync方法,更新其他组件日志 +- Config 合并实现方法 +- Render/BindProtobuf 无需proto文件进行编码 +- NewContextMessage 新增函数返回请求上下文消息,复用message。 +- middleware/gzip 使用自定义压缩函数,可以使用br压缩。 +- middleware/look 使用自定义data获取函数。 [2022年4月30日](https://github.com/eudore/eudore/tree/b80422e67f5c9907967e36e577d23220793a6c9c) - App和Context 生命周期管理 -- DataHandlerFunc 合并Bind Validate Filte Render +- DataHandlerFunc 合并Bind Validate Filte Render - Client 移入App组合 - Server 实现ServeConn方法 - ConvertTo 重构实现 @@ -31,9 +51,9 @@ Next - middleware/look 解析Accept Header为format值,模板内容优化。 - ConfigParseFunc ConfigParseFunc重构 - ResponseWriter WriteHeader将延时写入 -- contextBase 细节调整 +- contextBase 细节调整 - policy 增加401 -- httptest 修复响应对象并非读写 +- httptest 修复响应对象并非读写 [2021年8月31日](https://github.com/eudore/eudore/commit/627e6de1fa64c45873c70f86637efa2decc5763f) - Controller 简化内容保留ControllerAutoRoute。 diff --git a/README.md b/README.md index cc39395..486808a 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ # Eudore +[![Build Status](https://github.com/eudore/eudore/actions/workflows/action.yml/badge.svg)](https://github.com/eudore/eudore/actions/workflows/action.yml) [![godoc](https://godoc.org/github.com/eudore/eudore?status.svg)](https://godoc.org/github.com/eudore/eudore) [![go report card](https://goreportcard.com/badge/github.com/eudore/eudore)](https://goreportcard.com/report/github.com/eudore/eudore) [![codecov](https://codecov.io/gh/eudore/eudore/branch/master/graph/badge.svg)](https://codecov.io/gh/eudore/eudore) @@ -30,5 +31,4 @@ go get -v -u github.com/eudore/eudore - [godoc](https://godoc.org/github.com/eudore/eudore) - [example演示 100+](_example#example) - [wiki文档](https://github.com/eudore/eudore/wiki) -- [更新说明](CHANGELOG.md) -- [实践](https://github.com/eudore/website) +- [更新说明](CHANGELOG.md) \ No newline at end of file diff --git a/_example/README.md b/_example/README.md index 6825d27..1980382 100644 --- a/_example/README.md +++ b/_example/README.md @@ -13,11 +13,7 @@ go version go1.18.7 linux/amd64 coverage: 100.0% of statements in github.com/eud - [New](appNew.go) - [静态文件](appStatic.go) - [全局请求中间件](appMiddleware.go) - - [后台启动](appDaemon.go)(Alpha) - - [启动命令解析](appCommand.go)(Alpha) - - [监听代码自动编译重启](appNotify.go)(Alpha) - [自定义app](appExtend.go) - - [重新加载配置](appReload.go) - Config - [map存储配置](configMap.go) - [结构体存储配置](configEudore.go) @@ -37,6 +33,8 @@ go version go1.18.7 linux/amd64 coverage: 100.0% of statements in github.com/eud - [写入Elastic](loggerElastic.go) - [日志脱敏](loggerSensitive.go) - [logrus库适配](loggerLogrus.go) +- Client + - x - Server - [设置超时](serverStd.go) - [服务监听](serverListen.go) @@ -55,6 +53,11 @@ go version go1.18.7 linux/amd64 coverage: 100.0% of statements in github.com/eud - [路由器注册移除](routerDelete.go) - [路由器核心简化](routerCore.go) - [radix树](routerRadix.go) +- Controller + - [路由控制器](controllerAutoRoute.go) + - [控制器组合](controllerCompose.go) + - [控制器自定义参数](controllerParams.go) + - [控制器错误处理](controllerError.go) - Context - [Request Info](contextRequestInfo.go) - [Response Write](contextResponsWrite.go) @@ -76,7 +79,7 @@ go version go1.18.7 linux/amd64 coverage: 100.0% of statements in github.com/eud - [Send Template](contextRenderTemplate.go) - [文件上传](contextUpload.go) - [设置额外数据](contextValue.go) -- Context处理扩展 +- HandlerExtender - [默认处理](handlerDefault.go) - [处理ContextData扩展](handlerContextData.go) - [处理自定义函数类型](handlerFunc.go) @@ -88,11 +91,8 @@ go version go1.18.7 linux/amd64 coverage: 100.0% of statements in github.com/eud - [Rpc式map请求](handlerRpcMap.go) - [使用embed](handlerEmbed.go) - [使用jwt](handlerJwt.go) -- Controller - - [路由控制器](controllerAutoRoute.go) - - [控制器组合](controllerCompose.go) - - [控制器自定义参数](controllerParams.go) - - [控制器错误处理](controllerError.go) +- HandlerData + - FuncCreator - Middleware - [Admin中间件管理后台](middlewareAdmin.go) - [BasicAuth](middlewareBasicAuth.go) @@ -122,6 +122,11 @@ go version go1.18.7 linux/amd64 coverage: 100.0% of statements in github.com/eud - [RouterRewrite](middlewareRouterRewrite.go) - [Timeout请求超时](middlewareTimeout.go) - [自定义中间件处理函数](middlewareHandle.go) +- Daemon + - [后台启动](appDaemon.go) + - [命令管理进程](appCommand.go) + - [热重启](appRestart.go) + - [重新加载配置](appReload.go) - Policy(Alpha) - [Pbac](policyPbac.go) - [Rbac](policyRbac.go) @@ -140,10 +145,6 @@ go version go1.18.7 linux/amd64 coverage: 100.0% of statements in github.com/eud - Websocket - [使用github.com/gobwas/ws库](websocketGobwas.go) - [使用github.com/gorilla/websocket库](websocketGorilla.go) -- tool - - [转换对象成map](toolConvertMap.go) - - [对象转换](toolConvertTo.go) - - [基于路径读写对象](toolGetSet.go) - net/http - [中间件 黑名单](nethttpBlack.go) - [中间件 路径重写](nethttpRewrite.go) @@ -155,5 +156,6 @@ go version go1.18.7 linux/amd64 coverage: 100.0% of statements in github.com/eud - [http客户端简化实现](otherHttpClient.go) - [http服务端简化实现](otherHttpServer.go) - [http服务端简化](otherHttpServer2.go) + - [监听代码自动编译重启](otherNotify.go)(Alpha) diff --git a/_example/appCommand.go b/_example/appCommand.go index 26e780e..799a56d 100644 --- a/_example/appCommand.go +++ b/_example/appCommand.go @@ -11,7 +11,7 @@ go build -o server command包解析启动命令,支持start、daemon、status、stop、restart五个命令,需要定义command和pidfile两个配置参数。 通过向进程发送对应的系统信号实现对应的命令。 -该组件不支持win系统。 +该组件不支持windows系统。 */ import ( @@ -20,14 +20,20 @@ import ( "github.com/eudore/eudore" "github.com/eudore/eudore/daemon" + "github.com/eudore/eudore/middleware" ) func main() { app := eudore.NewApp() app.SetValue(eudore.ContextKeyLogger, eudore.NewLoggerInit()) - app.ParseOption(append(eudore.DefaultConfigAllParseFunc, daemon.NewParseCommand(app), NewParseLogger(app))) - app.SetValue(eudore.ContextKeyError, app.Parse()) + // append config parse + app.ParseOption( + daemon.NewParseDaemon(app), + NewParseLogger(app), + ) + app.Parse() + app.AddMiddleware(middleware.NewLoggerFunc(app)) app.GetFunc("/*", func(ctx eudore.Context) { ctx.WriteString("server daemon") }) @@ -35,7 +41,7 @@ func main() { go func() { select { case <-app.Done(): - case <-time.After(10 * time.Second): + case <-time.After(100 * time.Second): app.CancelFunc() } }() @@ -44,10 +50,10 @@ func main() { } func NewParseLogger(app *eudore.App) eudore.ConfigParseFunc { - return func(ctx context.Context, cnf eudore.Config) error { - app.SetValue(eudore.ContextKeyLogger, eudore.NewLoggerStd(&eudore.LoggerStdConfig{ - Std: true, - Path: "/tmp/daemon.log", + return func(context.Context, eudore.Config) error { + app.SetValue(eudore.ContextKeyLogger, eudore.NewLogger(&eudore.LoggerConfig{ + Stdout: true, + Path: "/tmp/daemon.log", })) return nil } diff --git a/_example/appDaemon.go b/_example/appDaemon.go index 54cd609..4ce5cb6 100644 --- a/_example/appDaemon.go +++ b/_example/appDaemon.go @@ -3,22 +3,21 @@ package main /* 通过Daemon()函数后台启动程序,也可以通过命令解析启动程序。 -当第一次启动时,使用os.Exec执行启动命令后台启动进程、关闭进程并附加环境变量,第二次启动时检测到环境变量即为后台启动,会忽略后台启动逻辑。然后执行正常启动。 +当第一次启动时,使用os.Exec执行启动命令后台启动进程、关闭进程并附加环境变量, +第二次启动时检测到环境变量即为后台启动,会忽略后台启动逻辑。然后执行正常启动。 -该组件不支持win系统。 +该组件不支持windows系统。 */ import ( - "fmt" - "os" - "os/exec" "time" "github.com/eudore/eudore" + "github.com/eudore/eudore/daemon" ) func main() { - Daemon() + daemon.StartDaemon() app := eudore.NewApp() app.GetFunc("/*", func(ctx eudore.Context) { @@ -35,17 +34,3 @@ func main() { app.Listen(":8088") app.Run() } - -// Daemon 函数直接后台启动程序。 -func Daemon(envs ...string) { - if eudore.GetStringBool(os.Getenv(eudore.EnvEudoreIsDaemon)) { - return - } - - cmd := exec.Command(os.Args[0], os.Args[1:]...) - cmd.Env = append(os.Environ(), fmt.Sprintf("%s=%d", eudore.EnvEudoreIsDaemon, 1)) - cmd.Env = append(cmd.Env, envs...) - cmd.Stdout = os.Stdout - cmd.Start() - os.Exit(0) -} diff --git a/_example/appExtend.go b/_example/appExtend.go index 3b110cb..b401fe2 100644 --- a/_example/appExtend.go +++ b/_example/appExtend.go @@ -12,6 +12,7 @@ App是一个自定义的程序主体,可以额外组合需要的App对象和 import ( "database/sql" + "github.com/eudore/eudore" ) @@ -45,7 +46,7 @@ func NewApp() *App { App: eudore.NewApp(), Config: conf, } - app.SetValue(eudore.ContextKeyConfig, eudore.NewConfigStd(conf)) + app.SetValue(eudore.ContextKeyConfig, eudore.NewConfig(conf)) return app } diff --git a/_example/appHealth.go b/_example/appHealth.go new file mode 100644 index 0000000..d91eb73 --- /dev/null +++ b/_example/appHealth.go @@ -0,0 +1,37 @@ +package main + +/* +HandlerMetadata返回App全部All的Metadata() any方法数据, +Metadata前两个字段为Health和Name。 + +type Metadata struct { + Health bool `alias:"health" json:"health" xml:"health" yaml:"health"` + Name string `alias:"name" json:"name" xml:"name" yaml:"name"` +} +*/ +import ( + "github.com/eudore/eudore" + "github.com/eudore/eudore/middleware" +) + +func main() { + app := eudore.NewApp() + app.SetValue(eudore.ContextKeyLogger, eudore.NewLogger(&eudore.LoggerConfig{ + Stdout: true, + StdColor: true, + HookMeta: true, + })) + app.SetValue(eudore.ContextKeyHandlerExtender, eudore.NewHandlerExtender()) + app.SetValue(eudore.ContextKeyFuncCreator, eudore.NewFuncCreator()) + app.SetValue(eudore.ContextKeyRender, eudore.RenderJSON) + app.SetValue(eudore.ContextKeyContextPool, eudore.NewContextBasePool(app)) + app.AddMiddleware(middleware.NewLoggerFunc(app)) + app.AddMiddleware(middleware.NewRecoverFunc()) + app.GetFunc("/health", eudore.HandlerMetadata) + app.GetFunc("/panic", func(ctx eudore.Context) { + panic(ctx) + }) + + app.Listen(":8087") + app.Run() +} diff --git a/_example/appNew.go b/_example/appNew.go index c8a88ba..cf6e045 100644 --- a/_example/appNew.go +++ b/_example/appNew.go @@ -6,10 +6,12 @@ eudore.App对象的简单组装各类对象,实现Value/SetValue、Listen和Ru import ( "github.com/eudore/eudore" + "github.com/eudore/eudore/middleware" ) func main() { app := eudore.NewApp() + app.AddMiddleware(middleware.NewLoggerFunc(app, "route")) app.AnyFunc("/*", func(ctx eudore.Context) { ctx.WriteString("hello eudore") }) diff --git a/_example/appNew2.go b/_example/appNew2.go new file mode 100644 index 0000000..baa9dec --- /dev/null +++ b/_example/appNew2.go @@ -0,0 +1,116 @@ +package main + +/* +eudore.App对象的简单组装各类对象,实现Value/SetValue、Listen和Run方法。 +*/ + +import ( + "net/http/httputil" + "net/url" + "strings" + "time" + + "github.com/eudore/eudore" + _ "github.com/eudore/eudore/daemon" + "github.com/eudore/eudore/middleware" +) + +func main() { + eudore.DefaultLoggerFormatterFormatTime = "none" + app := eudore.NewApp() + defer app.Run() + if app.Parse() != nil { + return + } + + addr, _ := url.Parse("http://127.0.0.1:30000/") + proxy := httputil.NewSingleHostReverseProxy(addr) + app.AddMiddleware("global", func(ctx eudore.Context) { + req := ctx.Request() + if strings.Contains(req.Host, ":8086") { + req.Host = "godoc.kube-public.eudore.cn:30000" + proxy.ServeHTTP(ctx.Response(), req) + ctx.End() + } + }) + + app.SetValue(eudore.ContextKeyHandlerExtender, eudore.NewHandlerExtender()) + app.SetValue(eudore.ContextKeyFuncCreator, eudore.NewFuncCreatorExpr()) + app.AddMiddleware( + middleware.NewHeaderFilteFunc(nil, nil), + middleware.NewRecoverFunc(), + middleware.NewLoggerFunc(app, "route"), + middleware.NewBasicAuthFunc(map[string]string{"eudore": "11"}), + middleware.NewCacheFunc(), + middleware.NewCompressMixinsFunc(nil), + middleware.NewDumpFunc(app.Group("/eudore/debug")), + ) + app.AddHandler("404", "", eudore.HandlerRouter404) + app.GetFunc("/bind", func(ctx eudore.Context) { + type Config struct { + Name []int + } + c := &Config{} + ctx.Bind(c) + ctx.Render(c) + }) + app.GetFunc("/fatal", func(ctx eudore.Context) { + ctx.Fatal(3) + }) + + app.AddController(&eudore.ControllerAutoRoute{}) + debug := app.Group("/eudore/debug") + debug.GetFunc("/admin/ui", middleware.HandlerAdmin) + debug.AnyFunc("/pprof/*", middleware.HandlerPprof) + debug.GetFunc("/look/*", middleware.NewLookFunc(app)) + debug.GetFunc("/meta/*", eudore.HandlerMetadata) + debug.GetFunc("/src/* autoindex=true", eudore.NewHandlerStatic(".")) + debug.GetFunc("/panic", func(ctx eudore.Context) { + panic(ctx.Path()) + }) + + app.Listen(":8086") + app.Listen(":8087") + + app.GetFunc("/values", func(ctx eudore.Context) { + ctx.Debug(ctx.GetHeader(eudore.HeaderContentType)) + ctx.FormValue("name") + ctx.WriteString("string") + }) + app.GetFunc("/text", func(ctx eudore.Context) { + ctx.Render("name") + }) + app.GetFunc("/log", func(ctx eudore.Context) { + ctx.FormFiles() + ctx.Info("info1") + ctx.WithField("l", 2).Info("info2") + ctx.WithField("depth", "stack").Info("info2") + }) + client := app.WithClient( + eudore.NewClientOptionBasicauth("eudore", "11"), + &eudore.ClientTrace{}, + time.Second*3, + ) + client.NewRequest(nil, "GET", "/values?d=1") + + app.Warning("-----------------") + client.NewRequest(nil, "GET", "/log?d=1") + app.Info("info") + app.AddHandlerExtend(999) + app.AddHandlerExtend(eudore.NewHandlerStringer) + app.Router.AddMiddleware(eudore.HandlerEmpty) + app.AddMiddleware(eudore.HandlerEmpty) + app.AddController(myController{}) + app.GetFunc("/xx", eudore.HandlerEmpty) + app.Warning("-----------------") + + // app.CancelFunc() +} + +type myController struct { + eudore.ControllerAutoRoute +} + +func (myController) Get() { + +} diff --git a/_example/appNotify.go b/_example/appNotify.go deleted file mode 100644 index 5c370b2..0000000 --- a/_example/appNotify.go +++ /dev/null @@ -1,38 +0,0 @@ -package main - -/* -先app.Config设置notify配置,然后启动notify。 -如果是notify的程序可以通过环境变量eudore.EnvEudoreIsNotify检测。 -当程序启动时会如果eudore.EnvEudoreIsNotify不存在,则使用notify开始监听阻塞app后续初始化,否在就忽略notify然后进行正常app启动。 - -实现原理基于fsnotify检测目录内go文件变化,然后执行编译命令,如果编译成功就kill原进程并执行启动命令。 - -其他类似工具:air -*/ - -import ( - "github.com/eudore/eudore" - "github.com/eudore/eudore/component/notify" -) - -func main() { - app := eudore.NewApp() - - // 设置编译命令、启动命令、监听目录, 如果是启动的notify,则阻塞主进程等待退出。 - app.Config.Set("component.notify.buildcmd", "go build -o server appNotify.go") - app.Config.Set("component.notify.startcmd", "./server") - app.Config.Set("component.notify.watchdir", ".") - n := notify.NewNotify(app) - if n.IsRun() { - // 启动日志输出 跳过后续初始化 - go app.Run() - n.Run() - return - } - - app.AnyFunc("/*", func(ctx eudore.Context) { - ctx.WriteString("hello eudore") - }) - app.Listen(":8088") - app.Run() -} diff --git a/_example/appReload.go b/_example/appReload.go index 90c61d2..7065357 100644 --- a/_example/appReload.go +++ b/_example/appReload.go @@ -1,9 +1,18 @@ package main +/* +添加配置解析NewParseDaemon函数,注册daemon.Signal管理信号。 +然后给信号10注册Reload函数。 +*/ + import ( + "context" "fmt" - "github.com/eudore/eudore" + "syscall" "time" + + "github.com/eudore/eudore" + "github.com/eudore/eudore/daemon" ) type AppReload struct { @@ -12,40 +21,50 @@ type AppReload struct { } type ConfigReload struct { - Name string `alias:"name" json:"name"` - Time string `alias:"time" json:"time"` + Workdir string `json:"workdir" alias:"workdir"` + Command string `json:"command" alias:"command"` + Pidfile string `json:"pidfile" alias:"pidfile"` + Name string `alias:"name" json:"name"` + Time string `alias:"time" json:"time"` } func main() { - app := NewAppReload() - app.Init() - // 访问reload 触发重新加载 - app.AnyFunc("/reload", app.Init) - - app.Listen(":8088") - app.Run() -} - -func NewAppReload() *AppReload { - conf := &ConfigReload{Name: "eudore"} + conf := &ConfigReload{ + Name: "eudore", + Time: time.Now().String(), + } app := &AppReload{ App: eudore.NewApp(), ConfigReload: conf, } // 使用读写路由核心,允许并发增删路由规则。 - app.SetValue(eudore.ContextKeyRouter, eudore.NewRouterStd(eudore.NewRouterCoreLock(nil))) - app.SetValue(eudore.ContextKeyConfig, eudore.NewConfigStd(conf)) - return app -} + app.SetValue(eudore.ContextKeyRouter, eudore.NewRouter(eudore.NewRouterCoreLock(nil))) + app.SetValue(eudore.ContextKeyConfig, eudore.NewConfig(conf)) + // 添加daemon对象 + app.ParseOption(daemon.NewParseDaemon(app.App)) + app.Parse() + if app.Err() != nil { + return + } -// Init 方法加载配置并注册路由 -func (app *AppReload) Init() error { - err := app.Parse() - if err != nil { - return err + // 注册系统信号reload + d, ok := app.Value(eudore.ContextKeyDaemonSignal).(*daemon.Signal) + if ok { + d.Register(syscall.Signal(0x0a), app.Reload) } + // 注册api触发reload + app.AnyFunc("/reload", app.Reload) + + app.AddController(NewUserReloadController(app)) + app.Listen(":8086") + app.Run() +} + +// Reload 方法加载配置并注册路由 +func (app *AppReload) Reload(context.Context) error { app.Time = time.Now().String() - return app.AddController(NewUserReloadController(app)) + app.Debug("app reload config") + return nil } type UserReloadController struct { @@ -58,7 +77,7 @@ func NewUserReloadController(app *AppReload) eudore.Controller { return &UserReloadController{Name: app.ConfigReload.Name, Config: app.App.Config} } -func (ctl UserReloadController) Any(ctx eudore.Context) interface{} { +func (ctl UserReloadController) Any(ctx eudore.Context) { // 使用属性或Get获取数据,Get方法带锁。 - return fmt.Sprintf("name is %s at %v", ctl.Name, ctl.Config.Get("time")) + ctx.WriteString(fmt.Sprintf("name is %s at %v", ctl.Name, ctl.Config.Get("time"))) } diff --git a/_example/appRestart.go b/_example/appRestart.go new file mode 100644 index 0000000..c4d3848 --- /dev/null +++ b/_example/appRestart.go @@ -0,0 +1,60 @@ +package main + +/* +按照命令执行程序。 + +go build -o server +./server --command=daemon +./server --command=status +./server --command=stop +./server --command=status + +command包解析启动命令,支持start、daemon、status、stop、restart五个命令,需要定义command和pidfile两个配置参数。 +通过向进程发送对应的系统信号实现对应的命令。 +该组件不支持windows系统。 +*/ + +import ( + "context" + "time" + + "github.com/eudore/eudore" + "github.com/eudore/eudore/daemon" +) + +func main() { + app := eudore.NewApp() + app.SetValue(eudore.ContextKeyLogger, eudore.NewLoggerInit()) + app.ParseOption( + daemon.NewParseDaemon(app), + NewParseLogger(app), + daemon.NewParseRestart(), + ) + app.Parse() + defer app.Run() + + if app.Err() == nil { + app.GetFunc("/*", func(ctx eudore.Context) { + ctx.WriteString("server daemon") + }) + app.Listen(":8087") + + go func() { + select { + case <-app.Done(): + case <-time.After(600 * time.Second): + app.CancelFunc() + } + }() + } +} + +func NewParseLogger(app *eudore.App) eudore.ConfigParseFunc { + return func(context.Context, eudore.Config) error { + app.SetValue(eudore.ContextKeyLogger, eudore.NewLogger(&eudore.LoggerConfig{ + Stdout: true, + Path: "/tmp/daemon.log", + })) + return nil + } +} diff --git a/_example/appStatic.go b/_example/appStatic.go index fd07456..0214c46 100644 --- a/_example/appStatic.go +++ b/_example/appStatic.go @@ -1,17 +1,34 @@ package main +/* +默认HandlerExtender注册扩展具有NewHandlerEmbed和NewHandlerHTTPFileSystem函数。, +可以传递emebd.FS http.FileSystem类型处理者。 + +NewFileSystems函数可以将string io/fs.FS(embed.FS) net/http.FileSystem(http.Dir)转换成eudore.HandlerFunc + +eudore.NewHandlerStatic(".") +eudore.NewHandlerEmbed(root) +http.Dir(".") +eudore.NewFileSystems(".", root) +*/ + import ( + "embed" + "net/http" + "github.com/eudore/eudore" ) +//go:embed *.go +var root embed.FS + func main() { app := eudore.NewApp() // 添加静态文件处理 - app.GetFunc("/js/*", NewStaticHandlerWithCache("", "public")) - // WriteFile 调用http.ServeFile实现,可以额外添加etag计算等逻辑,文件路径拼接需要注意清理。 - app.GetFunc("/css/*", func(ctx eudore.Context) { - ctx.WriteFile("static" + ctx.Path()) - }) + app.GetFunc("/src/*", eudore.NewFileSystems(".", root)) + app.GetFunc("/js/*", eudore.NewHandlerStatic(".")) + app.GetFunc("/js/*", http.Dir(".")) + app.GetFunc("/css/*", eudore.NewHandlerEmbed(root)) app.GetFunc("/*", func(ctx eudore.Context) { ctx.WriteString("hello eudore") }) @@ -22,7 +39,7 @@ func main() { // NewStaticHandlerWithCache 函数指定NewStaticHandler的缓存策略,默认为no-cache func NewStaticHandlerWithCache(path, policy string) eudore.HandlerFunc { - fn := eudore.NewStaticHandler("", path) + fn := eudore.NewHandlerStatic(path) return func(ctx eudore.Context) { ctx.SetHeader("Cache-Control", policy) fn(ctx) diff --git a/_example/app_test.go b/_example/app_test.go index 69f0e23..583cfa0 100644 --- a/_example/app_test.go +++ b/_example/app_test.go @@ -8,7 +8,7 @@ import ( ) func init() { - eudore.DefaultLoggerTimeFormat = "none" + eudore.DefaultLoggerFormatterFormatTime = "none" } func TestAppRun(*testing.T) { @@ -22,16 +22,12 @@ func TestAppRun(*testing.T) { ctx.WriteString("hello eudore") }) - app.NewRequest(nil, "GET", "/hello", - eudore.NewClientBodyString("trace"), - eudore.NewClientCheckStatus(200), - ) - app.Value(eudore.ContextKeyLogger) app.Value(eudore.ContextKeyConfig) app.Value(eudore.ContextKeyDatabase) app.Value(eudore.ContextKeyClient) app.Value(eudore.ContextKeyRouter) + app.Value(eudore.ContextKeyAppKeys) app.SetValue(eudore.ContextKeyError, "stop app") app.CancelFunc() diff --git a/_example/benchBuffer_test.go b/_example/benchBuffer_test.go new file mode 100644 index 0000000..00ec366 --- /dev/null +++ b/_example/benchBuffer_test.go @@ -0,0 +1,75 @@ +package eudore_test + +import ( + "bytes" + "testing" + "unsafe" +) + +func BenchmarkAppendBuff(b *testing.B) { + buf := bytes.NewBuffer(make([]byte, 2048)) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + buf.Reset() + for _, s := range strs { + buf.WriteString(s) + } + } +} + +func BenchmarkAppendCopy(b *testing.B) { + en := &encoder{make([]byte, 2048)} + b.ReportAllocs() + for i := 0; i < b.N; i++ { + en.data = en.data[0:0] + for _, s := range strs { + en.WriteString(s) + } + } +} + +type encoder struct { + data []byte +} + +func (en *encoder) WriteString(s string) { + b := *(*[]byte)(unsafe.Pointer(&struct { + string + Cap int + }{s, len(s)})) + en.data = append(en.data, b...) +} + +var strs = []string{ + "&{Context:context.Background.WithCancel", + "CancelFunc:0x4f7de0", + "Logger:0xc00013c5a0", + "Config:0xc000080730", + "Database:", + "Client:0xc000153da0", + "Server:0xc000080780", + "Ro", + "uter:0xc0001421c0", + "GetWarp:0x782160", + "HandlerFuncs:[github.com/eudore/eudore.(*App).serveContext-fm]", + "ContextPool:0xc0001650b0", + "CancelError:", + "cancelMutex:{state:0", + "sema:0}", + "Values:[bind", + "0x79fac0", + "render", + "0x7a0580", + "templdate", + "0xc000158060", + "handler-extender", + "0xc000160cc0", + "func-creator", + "0xc0001628c0]}&{RouterCore:0xc0000f5480", + "HandlerExtender:0xc0000b9800", + "Middlewares:0xc00008da90", + "GroupParams:route=", + "Logger:0xc000154270", + "LoggerKind:all", + "Meta:0xc00013c640}", +} diff --git a/_example/benchFuncRun_test.go b/_example/benchFuncRun_test.go new file mode 100644 index 0000000..bba112e --- /dev/null +++ b/_example/benchFuncRun_test.go @@ -0,0 +1,87 @@ +package eudore_test + +/* +goos: linux +goarch: amd64 +cpu: Intel(R) Xeon(R) Gold 6133 CPU @ 2.50GHz +BenchmarkReflectFuncString-2 3243481 401.2 ns/op 25 B/op 2 allocs/op +BenchmarkReflectFuncInt-2 2878328 376.2 ns/op 25 B/op 2 allocs/op +BenchmarkReflectFuncInterface-2 2321782 476.6 ns/op 41 B/op 3 allocs/op +BenchmarkRunFuncString-2 630790024 2.236 ns/op 0 B/op 0 allocs/op +BenchmarkRunFuncInt-2 511247667 2.575 ns/op 0 B/op 0 allocs/op +BenchmarkRunFuncInterface-2 615269007 1.986 ns/op 0 B/op 0 allocs/op +BenchmarkRunFuncStringKind-2 541799686 2.139 ns/op 0 B/op 0 allocs/op +PASS +ok command-line-arguments 10.828s +*/ + +import ( + "reflect" + "testing" +) + +var ( + ReflectString = reflect.ValueOf(stringIsZero) + ReflectInt = reflect.ValueOf(intIsZero) + ReflectInterface = reflect.ValueOf(interfaceIsZero) + FuncString = stringIsZero + FuncInt = intIsZero + FuncInterface = interfaceIsZero +) + +func stringIsZero(string) bool { return true } +func intIsZero(int) bool { return true } +func interfaceIsZero(interface{}) bool { return true } + +func BenchmarkReflectFuncString(b *testing.B) { + args := []reflect.Value{reflect.ValueOf("0")} + b.ReportAllocs() + for i := 0; i < b.N; i++ { + ReflectString.Call(args) + } +} + +func BenchmarkReflectFuncInt(b *testing.B) { + args := []reflect.Value{reflect.ValueOf(0)} + b.ReportAllocs() + for i := 0; i < b.N; i++ { + ReflectInt.Call(args) + } +} + +func BenchmarkReflectFuncInterface(b *testing.B) { + args := []reflect.Value{reflect.ValueOf(0)} + b.ReportAllocs() + for i := 0; i < b.N; i++ { + ReflectInterface.Call(args) + } +} + +func BenchmarkRunFuncString(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + FuncString("0") + } +} + +func BenchmarkRunFuncInt(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + FuncInt(0) + } +} + +func BenchmarkRunFuncInterface(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + FuncInterface(0) + } +} + +func BenchmarkRunFuncStringKind(b *testing.B) { + var fn interface{} = stringIsZero + b.ReportAllocs() + for i := 0; i < b.N; i++ { + fn.(func(string) bool)("0") + } +} diff --git a/_example/benchFunc_test.go b/_example/benchFunc_test.go index b1ede09..fe1d035 100644 --- a/_example/benchFunc_test.go +++ b/_example/benchFunc_test.go @@ -19,8 +19,8 @@ func (stmt *StmtSelect) Init(DatabaseContext) error { _ = stmt.Name return nil } -func (stmt *StmtSelect) Build(DatabaseContext) { +func (stmt *StmtSelect) Build(DatabaseContext) { } func BenchmarkFuncReflect(b *testing.B) { @@ -29,7 +29,7 @@ func BenchmarkFuncReflect(b *testing.B) { stmt.Build(ctx) }) var ctx DatabaseContext = 0 - var stmt = &StmtSelect{"eudore"} + stmt := &StmtSelect{"eudore"} b.ReportAllocs() for i := 0; i < b.N; i++ { reflect.TypeOf(stmt) @@ -44,7 +44,7 @@ func BenchmarkFuncAassertions(b *testing.B) { stmtSelect.Build(ctx) } var ctx DatabaseContext = 0 - var stmt = &StmtSelect{"eudore"} + stmt := &StmtSelect{"eudore"} b.ReportAllocs() for i := 0; i < b.N; i++ { reflect.TypeOf(stmt) diff --git a/_example/benchHeader_test.go b/_example/benchHeader_test.go index 54bf0cd..7aeda27 100644 --- a/_example/benchHeader_test.go +++ b/_example/benchHeader_test.go @@ -20,6 +20,7 @@ func BenchmarkHead1(b *testing.B) { textproto.CanonicalMIMEHeaderKey("Content-Disposition") } } + func BenchmarkHead2(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { diff --git a/_example/benchName_test.go b/_example/benchName_test.go new file mode 100644 index 0000000..8f61855 --- /dev/null +++ b/_example/benchName_test.go @@ -0,0 +1,51 @@ +package eudore_test + +/* +goos: linux +goarch: amd64 +cpu: Intel(R) Xeon(R) Gold 6133 CPU @ 2.50GHz +BenchmarkReflectTypeNameString-2 48612525 27.50 ns/op 0 B/op 0 allocs/op +BenchmarkReflectTypeNameInt-2 49561389 25.95 ns/op 0 B/op 0 allocs/op +BenchmarkReflectTypeEqualString-2 211048983 4.910 ns/op 0 B/op 0 allocs/op +BenchmarkReflectTypeEqualInt-2 215496802 5.405 ns/op 0 B/op 0 allocs/op +PASS +ok command-line-arguments 6.036s +*/ + +import ( + "reflect" + "testing" +) + +const ( + String = "string" + Int = 0 +) + +func BenchmarkReflectTypeNameString(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _ = reflect.TypeOf(String).String() == reflect.TypeOf(String).String() + } +} + +func BenchmarkReflectTypeNameInt(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _ = reflect.TypeOf(Int).String() == reflect.TypeOf(Int).String() + } +} + +func BenchmarkReflectTypeEqualString(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _ = reflect.TypeOf(String) == reflect.TypeOf(String) + } +} + +func BenchmarkReflectTypeEqualInt(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _ = reflect.TypeOf(Int) == reflect.TypeOf(Int) + } +} diff --git a/_example/client_test.go b/_example/client_test.go index ec5039f..e5952da 100644 --- a/_example/client_test.go +++ b/_example/client_test.go @@ -5,65 +5,194 @@ import ( "context" "crypto/tls" "fmt" - "io/ioutil" + "io" + "net" "net/http" + "net/http/cookiejar" "net/url" + "strings" "testing" "time" "github.com/eudore/eudore" + "github.com/eudore/eudore/middleware" + "golang.org/x/net/http2" ) +func TestClientOptions(t *testing.T) { + app := eudore.NewApp() + app.SetValue(eudore.ContextKeyClient, app.WithClient( + &eudore.ClientOption{ + Values: url.Values{"debug": {"1"}}, + Header: http.Header{"X-Client": {"eudore"}}, + Trace: &eudore.ClientTrace{}, + }, + eudore.NewClientOptionUserAgent("Client-Eudore"), + eudore.NewClientOptionHost("eudore.cn"), + eudore.Cookie{Name: "name", Value: "eudore"}, + )) + + { + app.GetFunc("/jar", middleware.NewLoggerFunc(app), func(ctx eudore.Context) { + ctx.SetCookieValue("count", ctx.GetCookie("count")+"1", 0) + ctx.Debug(ctx.GetCookie("count")) + }) + jar, _ := cookiejar.New(nil) + client := app.WithClient(jar) + for i := 0; i < 5; i++ { + client.NewRequest(nil, "", "/jar") + } + } + { + app.GetFunc("/timeout", func(ctx eudore.Context) { + time.Sleep(time.Microsecond * 2) + }) + client := app.WithClient(time.Microsecond) + app.Info(client.NewRequest(nil, "", "/timeout")) + } + { + app.ListenTLS(":8089", "", "") + client := app.WithClient(&http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }) + app.Debug(app.NewRequest(nil, "", "https://localhost:8089/app")) + + trace := &eudore.ClientTrace{} + app.WithField("trace", trace).Debug(client.NewRequest(nil, "", "https://localhost:8089/client", trace)) + } + + { + app.GetClient().Get("/") + } + + { + eudore.NewClient().WithClient( + &http2.Transport{ + AllowHTTP: true, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + return tls.Dial(network, addr, cfg) + }, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + &http.Transport{}, + http.DefaultClient, + ) + } + + app.CancelFunc() + app.Run() +} + func TestClientRequest(t *testing.T) { + eudore.DefaultClinetLoggerLevel = eudore.LoggerDebug + defer func() { + eudore.DefaultClinetLoggerLevel = eudore.LoggerError + }() + app := eudore.NewApp() + app.AddMiddleware(middleware.NewRequestIDFunc(nil)) app.AnyFunc("/*", func(ctx eudore.Context) { ctx.Info("server body:", string(ctx.Body())) ctx.Write(ctx.Body()) }) app.AnyFunc("/ctx", func(ctx eudore.Context) { client := ctx.Value(eudore.ContextKeyClient).(eudore.Client).WithClient( - eudore.NewClientHeader(eudore.HeaderAuthorization, ctx.GetHeader(eudore.HeaderAuthorization)), + http.Header{ + eudore.HeaderAuthorization: {ctx.GetHeader(eudore.HeaderAuthorization)}, + }, ) ctx.SetValue(eudore.ContextKeyClient, client) ctx.NewRequest("GET", "/") }) - client := app.GetClient() - tp, ok := client.Transport.(*http.Transport) - if ok { - tp.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } + client := app.WithClient(&http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }) - app.Debug(app.NewRequest(nil, "GET", "/", - "eudore body", - url.Values{"name": []string{"eudore"}}, - http.Header{"Client": []string{"eudore-client"}}, + client.NewRequest(nil, "GET", "/", + strings.NewReader("eudore body"), + context.Background(), + time.Second, + url.Values{"name": {"eudore"}}, + http.Header{"Cookie": {"name=eudore-client"}}, &http.Cookie{Name: "name", Value: "eudore"}, - func(*http.Request) {}, - func(*http.Response) error { return nil }, - eudore.NewClientQuery("name", "eudore"), - eudore.NewClientQuerys(url.Values{"state": []string{"active"}}), - eudore.NewClientHeader("Client", "eudore"), - eudore.NewClientHeaders(http.Header{"Accept": []string{"application/json"}}), - eudore.NewClientCookie("id", "c6e2ada8-8715-465b-af25-f992723b5b0a"), - eudore.NewClientBasicAuth("eudore", "pass"), - eudore.NewClientTrace(), - eudore.NewClientDumpBody(), - )) - app.Debug(app.NewRequest(nil, "GET", "/", []byte("eudore bytes"))) - app.Debug(app.NewRequest(nil, "GET", "/", bytes.NewBufferString("eudore buffer"))) - app.Debug(app.NewRequest(nil, "GET", "/", ioutil.NopCloser(bytes.NewBufferString("eudore buffer")))) - app.Debug(app.NewRequest(nil, "GET", "\u007f")) - app.Debug(app.NewRequest(nil, "GET", "")) - app.Debug(app.NewRequest(nil, "GET", "/", - func(*http.Response) error { return fmt.Errorf("eudore client test error") }, - )) + eudore.Cookie{Name: "name", Value: "eudore"}, + eudore.Cookie{Name: "key1", Value: "key ,space"}, + eudore.Cookie{Name: "key2", Value: "key\x03invalid"}, + &eudore.ClientTrace{}, + &eudore.ClientOption{}, + ) app.Debug(app.NewRequest(nil, "GET", "/ctx")) + app.Debug(app.NewRequest(nil, "GET", "/ctx\x00")) + app.Debug(app.NewRequest(nil, "LOCK", "/ctx")) + + app.CancelFunc() + app.Run() +} + +func TestClientAuthorization(t *testing.T) { + digest := []string{ + `Digest realm="digest@eudore.cn", algorithm=MD5, nonce="H4GiTo0v", qop="auth, auth-int", opaque="CUYo5tdS"`, + `Digest realm="digest@eudore.cn", algorithm=MD5, nonce="H4GiTo0v", opaque="CUYo5tdS",qop="auth-int"`, + `Digest realm="digest@eudore.cn", algorithm=MD5, nonce="H4GiTo0v"`, + `Digest realm="digest@eudore.cn", algorithm=MD5-SESS, nonce="H4GiTo0v"`, + `Digest realm="digest@eudore.cn", algorithm=SHA-256, nonce="H4GiTo0v"`, + `Basic realm="digest@eudore.cn", algorithm=MD5, nonce="H4GiTo0v", opaque="CUYo5tdS"`, + `Digest realm="digest@eudore.cn", algorithm, nonce="H4GiTo0v", opaque="CUYo5tdS"`, + `Digest realm="digest@eudore.cn", algorithm=MD5, nonce="H4GiTo0v", cnonce="H4GiTo0v", opaque="CUYo5tdS"`, + `Digest realm="digest@eudore.cn", algorithm=RCR32, nonce="H4GiTo0v", opaque="CUYo5tdS"`, + `Digest realm="digest@eudore.cn", algorithm=MD5, nonce="H4GiTo0v", opaque="CUYo5tdS", qop=int`, + } + + app := eudore.NewApp() + app.AddMiddleware(middleware.NewLoggerFunc(app)) + app.GetFunc("/500", func(ctx eudore.Context) { + ctx.WriteHeader(eudore.StatusInternalServerError) + }) + app.GetFunc("/auth", func(ctx eudore.Context) { + ctx.Debug(ctx.GetHeader(eudore.HeaderAuthorization)) + }) + app.GetFunc("/digest", func(ctx eudore.Context) { + if ctx.GetHeader(eudore.HeaderAuthorization) == "" { + ctx.WriteHeader(eudore.StatusUnauthorized) + ctx.SetHeader(eudore.HeaderWWWAuthenticate, digest[eudore.GetAnyByString(ctx.GetQuery("d"), 0)]) + } else { + ctx.Debug(ctx.GetHeader(eudore.HeaderAuthorization)) + } + }) + + app.NewRequest(nil, "", "/auth", eudore.NewClientOptionBearer("Bearer .eyJ1c2VyX25hbWUiOiJHdWVzdCIsImV4cGlyYXRpb24iOjEwNDEzNzkyMDAwfQ.vNTXrJNVqRLLY01w6weQWMRo_HDeBeVpX4HZtVfYUBY")) + app.NewRequest(nil, "", "/auth", eudore.NewClientOptionBasicauth("Guest", "")) + + client := app.WithClient(eudore.NewClientRetryDigest("Guest", "Guest")) + for i := range digest { + client.NewRequest(nil, "", "/digest", url.Values{"d": {fmt.Sprint(i)}}) + } + client.NewRequest(nil, "", "/digest?d=1", strings.NewReader("digest body")) + client.NewRequest(nil, "", "/500") + client.NewRequest(nil, "", "/digest?d=1", io.NopCloser(strings.NewReader("digest body"))) + + form := eudore.NewClientBodyForm(nil) + form.AddFile("file", "name", strings.NewReader("file bodt")) + client.NewRequest(nil, "", "/digest?d=1", form) app.CancelFunc() app.Run() } +func TestClientRetry(t *testing.T) { + app := eudore.NewApp() + app.AddMiddleware(middleware.NewLoggerFunc(app)) + app.GetFunc("/502", func(ctx eudore.Context) { + ctx.WriteHeader(eudore.StatusBadGateway) + }) + client := app.WithClient(eudore.NewClientRetryNetwork(1)) + + client.NewRequest(nil, "", "/502", time.Second/10) + app.CancelFunc() + app.Run() +} + func TestClientRequestBody(t *testing.T) { type Body struct { Name string @@ -78,34 +207,37 @@ func TestClientRequestBody(t *testing.T) { ctx.Write(ctx.Body()) }) - app.Debug(app.NewRequest(nil, "GET", "/body/string", eudore.NewClientBodyString("eudore body string"))) - app.Debug(app.NewRequest(nil, "GET", "/body/json", Body{"eudor"})) - app.Debug(app.NewRequest(nil, "GET", "/body/jsonstruct", eudore.NewClientBodyJSON(struct{ Name string }{"eudore"}))) + app.Debug(app.NewRequest(nil, "GET", "/body/string", strings.NewReader("eudore body string"))) + app.Debug(app.NewRequest(nil, "GET", "/body/json", eudore.NewClientBodyJSON(nil))) + app.Debug(app.NewRequest(nil, "GET", "/body/jsonstruct", eudore.NewClientBodyJSON(Body{"eudor"}))) app.Debug(app.NewRequest(nil, "GET", "/body/jsonmap", eudore.NewClientBodyJSON(map[string]interface{}{"name": "eudore"}))) - app.Debug(app.NewRequest(nil, "GET", "/body/jsonvalue", eudore.NewClientBodyJSONValue("name", "eudore"))) - app.Debug(app.NewRequest(nil, "GET", "/bdoy/formvalue", - eudore.NewClientBodyFormValue("name", "eudore"), - eudore.NewClientBodyFormValues(map[string]string{"server": "eudore"}), - )) - app.Debug(app.NewRequest(nil, "GET", "/body/formfile", - eudore.NewClientBodyFormFile("file", "string.txt", "file string"), - eudore.NewClientBodyFormFile("file", "bytes.txt", []byte("file bytes")), - eudore.NewClientBodyFormFile("file", "buffer.txt", bytes.NewBufferString("file buffer")), - eudore.NewClientBodyFormFile("file", "rc.txt", ioutil.NopCloser(bytes.NewBufferString("file rc"))), - eudore.NewClientBodyFormFile("file", "none.txt", nil), - eudore.NewClientBodyFormLocalFile("file", "", "appNew.go"), - )) - app.Debug(app.NewRequest(nil, "PUT", "/body/json", eudore.NewClientHeader(eudore.HeaderContentType, eudore.MimeApplicationJSON), eudore.NewClientBody(Body{"eudore"}))) - app.Debug(app.NewRequest(nil, "PUT", "/body/json", eudore.NewClientHeader(eudore.HeaderContentType, eudore.MimeApplicationJSON), eudore.NewClientBody([]Body{{"eudore"}}))) - app.Debug(app.NewRequest(nil, "PUT", "/body/xml", eudore.NewClientHeader(eudore.HeaderContentType, eudore.MimeApplicationXML), eudore.NewClientBody(Body{"eudore"}))) - app.Debug(app.NewRequest(nil, "PUT", "/body/pb", eudore.NewClientHeader(eudore.HeaderContentType, eudore.MimeApplicationProtobuf), eudore.NewClientBody(&Body{"eudore"}))) - app.Debug(app.NewRequest(nil, "PUT", "/body/pb", eudore.NewClientHeader(eudore.HeaderContentType, "pb"), eudore.NewClientBody(Body{"eudore"}))) - - app.Debug(app.NewRequest(nil, "PUT", "/redirect", ioutil.NopCloser(bytes.NewBufferString("buffer rc")))) - app.Debug(app.NewRequest(nil, "PUT", "/redirect", eudore.NewClientBody(Body{"eudore"}))) - app.Debug(app.NewRequest(nil, "PUT", "/redirect", eudore.NewClientBodyString("eudore body string"))) - app.Debug(app.NewRequest(nil, "PUT", "/redirect", eudore.NewClientBodyJSONValue("name", "eudore"))) - app.Debug(app.NewRequest(nil, "PUT", "/redirect", eudore.NewClientBodyFormValue("name", "eudore"))) + app.Debug(app.NewRequest(nil, "GET", "/body/xml", eudore.NewClientBodyXML(Body{"eudor"}))) + app.Debug(app.NewRequest(nil, "GET", "/body/protobuf", eudore.NewClientBodyProtobuf(Body{"eudor"}))) + app.Debug(app.NewRequest(nil, "GET", "/body/form", eudore.NewClientBodyForm(url.Values{ + "name": {"eudore"}, + }))) + bodyForm := eudore.NewClientBodyForm(nil) + bodyForm.AddValue("name", "eudore") + bodyForm.AddFile("file", "bytes.txt", []byte("file bytes")) + bodyForm.AddFile("file", "buffer.txt", bytes.NewBufferString("file buffer")) + bodyForm.AddFile("file", "rc.txt", io.NopCloser(bytes.NewBufferString("file rc"))) + bodyForm.AddFile("file", "none.txt", nil) + bodyForm.AddFile("file", "", "appNew.go") + bodyForm.Close() + app.Debug(app.NewRequest(nil, "GET", "/body/formfile", bodyForm)) + + bodyForm = eudore.NewClientBodyForm(nil) + bodyForm.AddValue("name", "eudore") + bodyJSON := eudore.NewClientBodyJSON(nil) + bodyJSON.AddValue("name", "eudore") + bodyJSON.AddFile("file", "", "appNew.go") + bodyJSON = eudore.NewClientBodyJSON(&Body{}) + bodyJSON.AddValue("name", "eudore") + bodyJSON.Close() + + app.Debug(app.NewRequest(nil, "PUT", "/redirect", strings.NewReader("eudore body string"))) + app.Debug(app.NewRequest(nil, "PUT", "/redirect", bodyForm)) + app.Debug(app.NewRequest(nil, "PUT", "/redirect", bodyJSON)) app.CancelFunc() app.Run() @@ -113,99 +245,66 @@ func TestClientRequestBody(t *testing.T) { func TestClientResponse(t *testing.T) { app := eudore.NewApp() - app.AnyFunc("/*", func(ctx eudore.Context) { - ctx.Info("server body:", string(ctx.Body())) - ctx.Write(ctx.Body()) - }) - app.AnyFunc("/trace", func(ctx eudore.Context) { - ctx.SetHeader(eudore.HeaderXTraceID, "558ac45caefc87c517a7c1cf49918f1aeudore") + app.GetFunc("/body/*", func(eudore.Context) interface{} { + return eudore.MetadataConfig{Name: "config"} }) - app.GetFunc("/body", func(eudore.Context) interface{} { - return eudore.LoggerStdConfig{ - Std: true, - Path: "/tmp/client.log", - Level: eudore.LoggerInfo, - } + app.GetFunc("/proxy", func(ctx eudore.Context) { + ctx.NewRequest(ctx.Method(), "/body", + eudore.NewClientOptionHeader(eudore.HeaderAccept, ctx.GetHeader(eudore.HeaderAccept)), + eudore.NewClientOptionHeader(eudore.HeaderAcceptEncoding, ctx.GetHeader(eudore.HeaderAcceptEncoding)), + eudore.NewClienProxyWriter(ctx.Response()), + ) }) app.GetFunc("/err", func(eudore.Context) error { return fmt.Errorf("test err") }) - app.Debug(app.NewRequest(nil, "GET", "https://goproxy.cn", - eudore.NewClientTimeout(time.Second), - eudore.NewClientTrace(), - eudore.NewClientDumpHead(), - )) - app.Debug(app.NewRequest(nil, "GET", "https://golang.org", - eudore.NewClientTimeout(time.Second), - eudore.NewClientTrace(), - eudore.NewClientDumpHead(), - )) - app.Debug(app.NewRequest(context.Background(), "GET", "/", - eudore.NewClientTimeout(time.Second), - eudore.NewClientDumpHead(), - )) + // app.NewRequest(nil, "GET", "https://goproxy.cn", time.Second, &eudore.ClientTrace{}) + // app.NewRequest(nil, "GET", "https://golang.org", time.Second, &eudore.ClientTrace{}) + // app.NewRequest(context.Background(), "GET", "/", time.Second) app.NewRequest(nil, "GET", "/check/status", eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "GET", "/check/status", eudore.NewClientCheckStatus(201)) + app.NewRequest(nil, "GET", "/check/status", eudore.NewClientCheckStatus(404)) - app.NewRequest(nil, "GET", "/trace", eudore.NewClientCheckBody("201")) - app.NewRequest(nil, "GET", "/trace", NewClientBodyError(), eudore.NewClientCheckBody("201")) + for _, accept := range []string{eudore.MimeApplicationJSON, eudore.MimeApplicationXML, eudore.MimeApplicationProtobuf, eudore.MimeTextHTML} { + conf := &eudore.MetadataConfig{} + app.NewRequest(nil, "GET", "/body/proxy", + eudore.NewClientOptionHeader(eudore.HeaderAccept, accept), + eudore.NewClientParse(conf), + ) + app.Debugf("%#v", conf) + } - app.NewRequest(nil, "GET", "/trace", - NewClientBodyError(), - eudore.NewClientDumpBody(), - ) + clientBodyError := func(w *http.Response) error { + w.Body = &responseBody{} + return nil + } + var str string + app.NewRequest(nil, "GET", "/body/parse", eudore.NewClientParse(&str)) + app.NewRequest(nil, "GET", "/body/parse", clientBodyError, eudore.NewClientParse(&str)) + app.NewRequest(nil, "GET", "/body/parseif", eudore.NewClientParseIf(201, nil)) + app.NewRequest(nil, "GET", "/body/parsein", eudore.NewClientParseIn(300, 308, nil)) - var conf eudore.LoggerStdConfig - err := app.NewRequest(nil, "GET", "/body", - eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeApplicationJSON), - eudore.NewClientParse(&conf), - ) - app.Debugf("%v %v", conf, err) - err = app.NewRequest(nil, "GET", "/body", - eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeApplicationXML), - eudore.NewClientTrace(), - eudore.NewClientDumpBody(), - eudore.NewClientParse(&conf), - ) - app.Debugf("%v %v", conf, err) - app.NewRequest(nil, "GET", "/body", - eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeTextHTML), - eudore.NewClientParse(&conf), - ) - app.NewRequest(nil, "GET", "/body", - eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeApplicationProtobuf), - eudore.NewClientParse(&conf), - eudore.NewClientParseErr(), - ) - app.NewRequest(nil, "GET", "/err", - eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeApplicationJSON), - eudore.NewClientParseIf(200, &conf), - eudore.NewClientParseIn(200, 200, &conf), - eudore.NewClientParseErr(), - ) - app.NewRequest(nil, "GET", "/err", - eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeTextHTML), - eudore.NewClientParseErr(), - ) + app.NewRequest(nil, "GET", "/body/parsestr", eudore.NewClientParseErr(), eudore.NewClientParse(&str)) + app.NewRequest(nil, "GET", "/err", eudore.NewClientParseErr(), eudore.NewClientParse(&str)) + app.NewRequest(nil, "GET", "/err", clientBodyError, eudore.NewClientParseErr()) + + app.NewRequest(nil, "GET", "/body/check", eudore.NewClientCheckBody("config")) + app.NewRequest(nil, "GET", "/body/check", eudore.NewClientCheckBody("123456")) + app.NewRequest(nil, "GET", "/body/check", clientBodyError, eudore.NewClientCheckBody("")) + + app.NewRequest(nil, "GET", "/proxy") app.CancelFunc() app.Run() } -func NewClientBodyError() eudore.ClientResponseOption { - return func(w *http.Response) error { - w.Body = &responseBody{} - return nil - } -} - type responseBody struct{} func (r *responseBody) Read(p []byte) (int, error) { return 0, fmt.Errorf("test error") } + func (r *responseBody) Close() error { return nil } diff --git a/_example/configStd.go b/_example/configStd.go index cd99b89..be24371 100644 --- a/_example/configStd.go +++ b/_example/configStd.go @@ -1,42 +1,41 @@ package main /* -ConfigEudore需要使用指定对象保存配置,默认为map[string]interface{}。 +Config需要使用指定对象保存配置,默认为map[string]interface{}。 可以自己指定结构体来保存配置,例如example中Config对象指定的user.name就是展开的一层结构体或map后设置,详细查看eudore.Set函数的文档。 -config的Get & Set方法使用eudore.Get & eudore.Set方法实现。 +config的Get & Set方法使用eudore.GetAnyByPath & eudore.SetAnyByPath方法实现。 */ import ( "github.com/eudore/eudore" ) -type ( - eudoreConfig struct { - Bool bool `alias:"bool"` - Int int `alias:"int"` - String string `alias:"string"` - User user `alias:"user" flag:"u"` - Struct interface{} `alias:"struct"` - } - user struct { - Name string `alias:"name"` - Mail string `alias:"mail"` - } -) +type eudoreConfig struct { + Bool bool `alias:"bool"` + Int int `alias:"int"` + String string `alias:"string"` + User user `alias:"user" flag:"u"` + Any interface{} `alias:"any"` +} +type user struct { + Name string `alias:"name"` + Mail string `alias:"mail"` +} func main() { conf := &eudoreConfig{} app := eudore.NewApp() - app.SetValue(eudore.ContextKeyConfig, eudore.NewConfigStd(conf)) + app.SetValue(eudore.ContextKeyConfig, eudore.NewConfig(conf)) + app.Parse() - // 设属性 + // 设置属性 app.Set("int", 20) app.Set("string", "app set string") app.Set("bool", true) app.Set("user.name", "EudoreName") - app.Set("struct", struct { + app.Set("any", struct { Name string Age int }{"eudore", 2020}) @@ -47,7 +46,8 @@ func main() { app.Debugf("%#v", app.GetInt("string")) app.Debugf("%#v", app.GetString("string")) app.Debugf("%#v", app.GetBool("bool")) - app.Debugf("%#v", app.Get("struct")) + app.Debugf("%#v", app.Get("user")) + app.Debugf("%#v", app.Get("any")) app.Debugf("%#v", app.Get("field")) // 输出全部配置信息 diff --git a/_example/config_test.go b/_example/config_test.go index c3d9272..0a9c664 100644 --- a/_example/config_test.go +++ b/_example/config_test.go @@ -14,11 +14,12 @@ import ( func TestConfigStdGetSet(t *testing.T) { app := eudore.NewApp() - app.SetValue(eudore.ContextKeyConfig, eudore.NewConfigStd(map[string]interface{}{ + app.SetValue(eudore.ContextKeyConfig, eudore.NewConfig(map[string]interface{}{ "name": "eudore", "type": "ConfigMap", "number": 3, })) + app.Parse() app.Set("auth.secret", "secret") app.Infof("data: %# v", app.Get("")) app.Infof("data name: %v", app.Get("name")) @@ -30,6 +31,7 @@ func TestConfigStdGetSet(t *testing.T) { app.Set("", &Config{Name: "eudore"}) app.Set("type", "config") app.Infof("data name: %v", app.Get("name")) + app.Config.(interface{ Metadata() any }).Metadata() app.CancelFunc() app.Run() @@ -37,17 +39,18 @@ func TestConfigStdGetSet(t *testing.T) { func TestConfigStdpParse(t *testing.T) { app := eudore.NewApp() - app.ParseOption([]eudore.ConfigParseFunc{func(ctx context.Context, config eudore.Config) error { + app.ParseOption(func(ctx context.Context, config eudore.Config) error { config.Set("parse", true) return nil - }}) + }) app.Infof("parse eror: %v", app.Parse()) app.Infof("data: %# v", app.Get("")) - app.ParseOption([]eudore.ConfigParseFunc{func(ctx context.Context, config eudore.Config) error { + app.ParseOption() + app.ParseOption(func(ctx context.Context, config eudore.Config) error { config.Set("error", true) return errors.New("parse test error") - }}) + }) app.Infof("parse eror: %v", app.Parse()) app.Infof("parse eror: %v", app.Parse()) app.Infof("data: %# v", app.Get("")) @@ -58,9 +61,9 @@ func TestConfigStdpParse(t *testing.T) { func TestConfigStdJSON(t *testing.T) { app := eudore.NewApp() - app.ParseOption([]eudore.ConfigParseFunc{func(ctx context.Context, config eudore.Config) error { + app.ParseOption(func(ctx context.Context, config eudore.Config) error { return json.Unmarshal([]byte(`{"name":"eudore"}`), config) - }}) + }) app.Infof("ConfigMap parse eror: %v", app.Parse()) app.Infof("ConfigMap data: %# v", app.Get("")) @@ -78,7 +81,8 @@ func TestConfigParseJSON(t *testing.T) { defer tempConfigFile(filepath2, `name:eudore`)() app := eudore.NewApp() - app.ParseOption([]eudore.ConfigParseFunc{eudore.NewConfigParseJSON("config")}) + app.ParseOption() + app.ParseOption(eudore.NewConfigParseJSON("config")) app.Infof("NewConfigParseJSON parse empty error %v:", app.Parse()) @@ -113,9 +117,10 @@ func tempConfigFile(path, content string) func() { } func TestConfigParseArgs(t *testing.T) { - os.Args = append(os.Args, "--name=eudore") + os.Args = append(os.Args, "start", "--name=eudore") app := eudore.NewApp() - app.ParseOption([]eudore.ConfigParseFunc{eudore.NewConfigParseArgs(nil)}) + app.ParseOption() + app.ParseOption(eudore.NewConfigParseArgs(nil)) app.Infof("NewConfigParseArgs parse error: %v", app.Parse()) app.Infof("Config data: %# v", app.Get("")) @@ -136,8 +141,9 @@ func TestConfigParseArgsShort(t *testing.T) { os.Args = append(os.Args, "--name=eudore", "-f=config.json", "-h", "--help") app := eudore.NewApp() - app.SetValue(eudore.ContextKeyConfig, eudore.NewConfigStd(&configShort{false, "eudore", "msg"})) - app.ParseOption([]eudore.ConfigParseFunc{eudore.NewConfigParseArgs(shortMapping)}) + app.SetValue(eudore.ContextKeyConfig, eudore.NewConfig(&configShort{false, "eudore", "msg"})) + app.ParseOption() + app.ParseOption(eudore.NewConfigParseArgs(shortMapping)) app.Infof("NewConfigParseArgs parse error: %v", app.Parse()) app.Infof("Config data: %# v", app.Get("")) @@ -151,7 +157,8 @@ func TestConfigParseEnvs(t *testing.T) { defer os.Unsetenv("ENV_NAME") // init envs by cmd app := eudore.NewApp() - app.ParseOption([]eudore.ConfigParseFunc{eudore.NewConfigParseEnvs("ENV_")}) + app.ParseOption() + app.ParseOption(eudore.NewConfigParseEnvs("ENV_")) app.Infof("NewConfigParseEnvs parse error: %v", app.Parse()) app.Infof("Config data: %# v", app.Get("")) @@ -160,9 +167,31 @@ func TestConfigParseEnvs(t *testing.T) { app.Run() } +func TestConfigParseEnvsFile(t *testing.T) { + defer tempConfigFile(".env", "A=2\r\nB=\r\nC='2\r\n2\r\n2'\r\nD='2\r\n2\r\n2'2\r\n")() + defer os.Unsetenv("ENV_NAME") + // init envs by cmd + + p := "out----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------1.log" + app := eudore.NewApp() + app.ParseOption() + app.ParseOption(eudore.NewConfigParseEnvFile(".env", p)) + + app.Infof("NewConfigParseEnvFile parse error: %v", app.Parse()) + app.Infof("Config data: %# v", app.Get("")) + + app.ParseOption() + app.ParseOption(eudore.NewConfigParseEnvFile()) + app.Parse() + + app.CancelFunc() + app.Run() +} + func TestConfigParseWorkdir(t *testing.T) { app := eudore.NewApp() - app.ParseOption([]eudore.ConfigParseFunc{eudore.NewConfigParseWorkdir("workdir")}) + app.ParseOption() + app.ParseOption(eudore.NewConfigParseWorkdir("workdir")) app.Infof("NewConfigParseWorkdir parse empty dir error: %v", app.Parse()) @@ -179,8 +208,9 @@ func TestConfigParseHelp(t *testing.T) { conf.Link = conf app := eudore.NewApp() - app.SetValue(eudore.ContextKeyConfig, eudore.NewConfigStd(conf)) - app.ParseOption([]eudore.ConfigParseFunc{eudore.NewConfigParseHelp("help")}) + app.SetValue(eudore.ContextKeyConfig, eudore.NewConfig(conf)) + app.ParseOption() + app.ParseOption(eudore.NewConfigParseHelp("help")) app.Infof("NewConfigParseHelp parse not help error: %v", app.Parse()) app.Set("help", true) @@ -221,12 +251,12 @@ type Node struct { // ComponentConfig 定义website使用的组件的配置。 type helpComponentConfig struct { - DB helpDBConfig `json:"db" alias:"db"` - Logger *eudore.LoggerStdConfig `json:"logger" alias:"logger"` - Server *eudore.ServerStdConfig `json:"server" alias:"server"` - Notify map[string]string `json:"notify" alias:"notify"` - Pprof *helpPprofConfig `json:"pprof" alias:"pprof"` - Black map[string]bool `json:"black" alias:"black"` + DB helpDBConfig `json:"db" alias:"db"` + Logger *eudore.LoggerConfig `json:"logger" alias:"logger"` + Server *eudore.ServerConfig `json:"server" alias:"server"` + Notify map[string]string `json:"notify" alias:"notify"` + Pprof *helpPprofConfig `json:"pprof" alias:"pprof"` + Black map[string]bool `json:"black" alias:"black"` } type helpDBConfig struct { Driver string `json:"driver" alias:"driver" description:"database driver type"` diff --git a/_example/contextPush.go b/_example/contextPush.go index 7eeea2a..308a669 100644 --- a/_example/contextPush.go +++ b/_example/contextPush.go @@ -3,57 +3,50 @@ package main import ( "crypto/tls" "net" + "net/http" "github.com/eudore/eudore" - "github.com/eudore/eudore/component/httptest" + "github.com/eudore/eudore/middleware" "golang.org/x/net/http2" ) func main() { app := eudore.NewApp() + app.AddMiddleware( + middleware.NewLoggerFunc(app), + middleware.NewCompressMixinsFunc(nil), + ) app.GetFunc("/", func(ctx eudore.Context) { - ctx.Debug(ctx.Request().Proto) - ctx.Push("/css/1.css", nil) - ctx.Push("/css/2.css", nil) - ctx.Push("/css/3.css", nil) - ctx.Push("/favicon.ico", nil) + ctx.Push("/css/app.css", &http.PushOptions{ + Header: http.Header{eudore.HeaderAuthorization: {"00"}}, + }) + ctx.SetHeader(eudore.HeaderContentType, eudore.MimeTextHTMLCharsetUtf8) ctx.WriteString(` - - push - - - - - -push test - +push +push test, push css is red font. `) }) - app.GetFunc("/hijack", func(ctx eudore.Context) { - conn, _, err := ctx.Response().Hijack() - if err == nil { - conn.Close() - } - }) app.GetFunc("/css/*", func(ctx eudore.Context) { - ctx.WriteString("*{}") + if ctx.GetHeader(eudore.HeaderAuthorization) == "" { + ctx.WriteHeader(eudore.StatusUnauthorized) + return + } + ctx.WithField("header", ctx.Request().Header).Debug() + ctx.SetHeader(eudore.HeaderContentType, "text/css") + ctx.WriteString("*{color: red;}") }) - app.ListenTLS(":8088", "", "") + app.Listen(":8088") + app.ListenTLS(":8089", "", "") - client := httptest.NewClient(app) - client.Client.Transport = &http2.Transport{ + client := app.WithClient(&http2.Transport{ AllowHTTP: true, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { return tls.Dial(network, addr, cfg) }, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - client.NewRequest("GET", "/").Do().CheckStatus(200).Out() - client.NewRequest("GET", "https://localhost:8088/").Do().CheckStatus(200).Out() - client.NewRequest("GET", "https://localhost:8088/hijack").Do().CheckStatus(200).Out() + }) + client.NewRequest(nil, "GET", "https://localhost:8089/", eudore.NewClientCheckStatus(200)) - app.Listen(":8088") - // app.CancelFunc() app.Run() } diff --git a/_example/context_test.go b/_example/context_test.go index 3485a41..dd1ff0a 100644 --- a/_example/context_test.go +++ b/_example/context_test.go @@ -8,6 +8,8 @@ import ( "html/template" "net" "net/http" + "net/url" + "strings" "testing" "github.com/eudore/eudore" @@ -38,7 +40,7 @@ func TestContext(*testing.T) { func TestContextRequest(*testing.T) { app := eudore.NewApp() - app.SetValue(eudore.ContextKeyValidate, func(eudore.Context, interface{}) error { return nil }) + app.SetValue(eudore.ContextKeyValidater, func(eudore.Context, interface{}) error { return nil }) app.SetValue(eudore.ContextKeyContextPool, eudore.NewContextBasePool(app)) app.AddMiddleware("global", middleware.NewRequestIDFunc(nil)) @@ -76,16 +78,17 @@ func TestContextRequest(*testing.T) { app.NewRequest(nil, "GET", "/info") app.NewRequest(nil, "GET", "/realip") - app.NewRequest(nil, "GET", "/realip", eudore.NewClientHeader(eudore.HeaderXRealIP, "47.11.11.11")) - app.NewRequest(nil, "GET", "/realip", eudore.NewClientHeader(eudore.HeaderXForwardedFor, "47.11.11.11")) + app.NewRequest(nil, "GET", "http://localhost:8088/realip") + app.NewRequest(nil, "GET", "/realip", http.Header{eudore.HeaderXRealIP: {"47.11.11.11"}}) + app.NewRequest(nil, "GET", "/realip", http.Header{eudore.HeaderXForwardedFor: {"47.11.11.11"}}) app.NewRequest(nil, "GET", "/bind", eudore.NewClientBodyJSON(bindData{"eudore"})) app.NewRequest(nil, "GET", "/bind", - eudore.NewClientHeader(eudore.HeaderContentType, eudore.MimeApplicationJSON), - eudore.NewClientBodyString("eudore"), + http.Header{eudore.HeaderContentType: {eudore.MimeApplicationJSON}}, + strings.NewReader("eudore"), ) app.NewRequest(nil, "POST", "/bind", - eudore.NewClientHeader(eudore.HeaderContentType, "value"), - eudore.NewClientBodyString("eudore"), + http.Header{eudore.HeaderContentType: {"value"}}, + strings.NewReader("eudore"), ) app.CancelFunc() @@ -104,6 +107,24 @@ func (bodyError) Close() error { func TestContextData(*testing.T) { app := eudore.NewApp() + app.AddMiddleware(func(ctx eudore.Context) { + d := ctx.GetHeader("Debug") + if d == "" { + return + } + r := ctx.Request() + switch d { + case "uri": + r.URL.RawQuery = "tag=%\007" + case "cookie": + r.Header.Add(eudore.HeaderCookie, "age=22; =00; tag=\007hs; aa=\"bb\"; ") + case "body": + r.Body = bodyError{} + } + }) + app.AnyFunc("/body", func(ctx eudore.Context) { + ctx.Body() + }) app.AnyFunc("/* version=v0", func(ctx eudore.Context) { ctx.Info(ctx.Method(), ctx.Path(), ctx.Params().String(), string(ctx.Body())) }) @@ -111,22 +132,15 @@ func TestContextData(*testing.T) { ctx.SetParam("name", "eudore") ctx.Info("params", ctx.Params().String(), ctx.GetParam("name")) }) - app.AnyFunc("/querys", func(ctx eudore.Context) { - ctx.Debug(string(ctx.Body()), ctx.Request().RequestURI) - ctx.Info("querys", ctx.Querys()) + app.AnyFunc("/query", func(ctx eudore.Context) { ctx.Info("query name", ctx.GetQuery("name")) }) - app.AnyFunc("/querys-err1", func(ctx eudore.Context) { - ctx.Request().URL.RawQuery = "tag=%\007" + app.AnyFunc("/querys", func(ctx eudore.Context) { ctx.Info("querys", ctx.Querys()) }) - app.AnyFunc("/querys-err2", func(ctx eudore.Context) { - ctx.Request().URL.RawQuery = "tag=%\007" - ctx.Info("query name", ctx.GetQuery("name")) - }) // cookie app.AnyFunc("/cookie-set", func(ctx eudore.Context) { - ctx.SetCookie(&eudore.SetCookie{ + ctx.SetCookie(&eudore.CookieSet{ Name: "set1", Value: "val1", Path: "/", @@ -138,58 +152,88 @@ func TestContextData(*testing.T) { }) app.AnyFunc("/cookie-get", func(ctx eudore.Context) { ctx.Info("cookie", ctx.GetHeader(eudore.HeaderCookie)) - ctx.Infof("cookie name value is: %s", ctx.GetCookie("name")) - ctx.Infof("cookie age value is: %s", ctx.GetCookie("age")) + ctx.GetCookie("name") for _, i := range ctx.Cookies() { fmt.Fprintf(ctx, "%s: %s\n", i.Name, i.Value) } }) - app.AnyFunc("/cookie-err", func(ctx eudore.Context) { - ctx.Request().Header.Add(eudore.HeaderCookie, "age=22; =00; tag=\007hs; aa=\"bb\"; ") - ctx.Info("cookies", ctx.Cookies()) - }) // form app.AnyFunc("/form-value", func(ctx eudore.Context) { ctx.Info("form value name:", ctx.FormValue("name")) - ctx.Info("form value group:", ctx.FormValue("group")) + }) + app.AnyFunc("/form-values", func(ctx eudore.Context) { ctx.Info("form values:", ctx.FormValues()) }) app.AnyFunc("/form-file", func(ctx eudore.Context) { - ctx.Infof("%s", ctx.Body()) - ctx.Infof("form value name: %#v", ctx.FormFile("file")) - ctx.Infof("form value group: %#v", ctx.FormFile("name")) - ctx.Infof("form values: %#v", ctx.FormFiles()) - }) - app.AnyFunc("/form-err", func(ctx eudore.Context) { - ctx.FormValue("name") - ctx.FormValues() - ctx.FormFile("file") - ctx.FormFiles() - }) - app.AnyFunc("/body", func(ctx eudore.Context) { - ctx.Request().Body = bodyError{} - ctx.Body() + ctx.Infof("form file: %#v", ctx.FormFile("file")) }) - app.AnyFunc("/read", func(ctx eudore.Context) { - body := make([]byte, 4096) - ctx.Read(body) + app.AnyFunc("/form-files", func(ctx eudore.Context) { + ctx.Infof("form values: %#v", ctx.FormFiles()) }) app.NewRequest(nil, "GET", "/") + app.NewRequest(nil, "GET", "/body") + app.NewRequest(nil, "GET", "/body", eudore.NewClientOptionHeader("Debug", "body")) app.NewRequest(nil, "GET", "/params") + app.NewRequest(nil, "GET", "/query?name=eudore&debug=true") app.NewRequest(nil, "GET", "/querys?name=eudore&debug=true") - app.NewRequest(nil, "PUT", "/querys-err1") - app.NewRequest(nil, "PUT", "/querys-err2") - app.NewRequest(nil, "GET", "/cookie-get") + app.NewRequest(nil, "PUT", "/query", eudore.NewClientOptionHeader("Debug", "uri")) + app.NewRequest(nil, "PUT", "/querys", eudore.NewClientOptionHeader("Debug", "uri")) app.NewRequest(nil, "GET", "/cookie-set") - app.NewRequest(nil, "GET", "/cookie-get") - app.NewRequest(nil, "GET", "/cookie-get", http.Header{eudore.HeaderCookie: []string{"age=22"}}) - app.NewRequest(nil, "GET", "/cookie-err") - app.NewRequest(nil, "GET", "/form-value", eudore.NewClientBodyFormValue("name", "eudore")) - app.NewRequest(nil, "GET", "/form-file", eudore.NewClientBodyFormFile("file", "app.txt", "eudore app")) - app.NewRequest(nil, "GET", "/form-err", eudore.NewClientBodyString("name=eudore")) - app.NewRequest(nil, "GET", "/body") - app.NewRequest(nil, "GET", "/read") + app.NewRequest(nil, "GET", "/cookie-get", + eudore.Cookie{"age", "22"}, + eudore.Cookie{"name", "a, b"}, + eudore.Cookie{"valid", "key\x03invalid"}, + http.Header{eudore.HeaderCookie: {"age=22;;;"}}, + ) + app.NewRequest(nil, "GET", "/cookie-get", eudore.NewClientOptionHeader("Debug", "cookie")) + app.NewRequest(nil, "GET", "/form-value", eudore.NewClientBodyForm(url.Values{"name": {"eudor"}})) + app.NewRequest(nil, "GET", "/form-value", eudore.NewClientBodyForm(url.Values{"key": {"eudor"}})) + app.NewRequest(nil, "GET", "/form-values?name=eudore") + app.NewRequest(nil, "GET", "/form-values", eudore.NewClientOptionHeader("Debug", "uri")) + + body := eudore.NewClientBodyForm(nil) + body.AddFile("file", "app.txt", strings.NewReader("eudore app")) + app.NewRequest(nil, "GET", "/form-file", body) + body = eudore.NewClientBodyForm(nil) + body.AddFile("name", "app.txt", strings.NewReader("eudore app")) + app.NewRequest(nil, "GET", "/form-file", body) + app.NewRequest(nil, "GET", "/form-file", + eudore.NewClientOptionHeader(eudore.HeaderContentType, eudore.MimeText), + strings.NewReader("body"), + ) + body = eudore.NewClientBodyForm(nil) + body.AddFile("file", "app.txt", strings.NewReader("eudore app")) + app.NewRequest(nil, "GET", "/form-files", body) + app.NewRequest(nil, "GET", "/form-files") + app.NewRequest(nil, "GET", "/form-files", + eudore.NewClientOptionHeader(eudore.HeaderContentType, eudore.MimeText), + strings.NewReader("body"), + ) + + app.NewRequest(nil, "GET", "/form-value", + eudore.NewClientOptionHeader("Debug", "body"), + eudore.NewClientOptionHeader(eudore.HeaderContentType, eudore.MimeApplicationForm), + ) + app.NewRequest(nil, "GET", "/form-value", + strings.NewReader("name=%\007"), + ) + app.NewRequest(nil, "GET", "/form-value", + eudore.NewClientOptionHeader(eudore.HeaderContentType, eudore.MimeApplicationForm), + strings.NewReader("name=%\007"), + ) + app.NewRequest(nil, "GET", "/form-value", + eudore.NewClientOptionHeader(eudore.HeaderContentType, eudore.MimeMultipartForm), + strings.NewReader("body"), + ) + app.NewRequest(nil, "GET", "/form-value", + eudore.NewClientOptionHeader(eudore.HeaderContentType, eudore.MimeMultipartForm+"; boundary=x"), + strings.NewReader("body"), + ) + app.NewRequest(nil, "GET", "/form-value", + eudore.NewClientOptionHeader(eudore.HeaderContentType, eudore.MimeText), + strings.NewReader("body"), + ) app.CancelFunc() app.Run() @@ -208,6 +252,10 @@ func (w *responseError) Write([]byte) (int, error) { return 0, fmt.Errorf("test response Write error") } +func (w *responseError) WriteString(string) (int, error) { + return 0, fmt.Errorf("test response Write error") +} + func (w *responseError) WriteHeader(code int) { w.code = code } @@ -232,9 +280,13 @@ func (w *responseError) Status() int { func TestContextResponse(*testing.T) { app := eudore.NewApp() - app.SetValue(eudore.ContextKeyFilte, func(eudore.Context, interface{}) error { return nil }) + app.SetValue(eudore.ContextKeyFilter, func(eudore.Context, interface{}) error { return nil }) app.SetValue(eudore.ContextKeyContextPool, eudore.NewContextBasePool(app)) app.AddMiddleware(func(ctx eudore.Context) { + unwarper, ok := ctx.Response().(interface{ Unwrap() http.ResponseWriter }) + if ok { + unwarper.Unwrap() + } if ctx.GetQuery("debug") != "" { ctx.SetResponse(&responseError{headers: make(http.Header)}) } @@ -246,6 +298,9 @@ func TestContextResponse(*testing.T) { app.AnyFunc("/redirect", func(ctx eudore.Context) { ctx.Redirect(308, "/") }) + app.AnyFunc("/redirect200", func(ctx eudore.Context) { + ctx.Redirect(200, "/") + }) app.AnyFunc("/ws", func(ctx eudore.Context) { conn, _, err := ctx.Response().Hijack() if err == nil { @@ -285,24 +340,25 @@ func TestContextResponse(*testing.T) { app.ListenTLS(":8089", "", "") app.NewRequest(nil, "GET", "/redirect") + app.NewRequest(nil, "GET", "/redirect200") app.NewRequest(nil, "GET", "/push") app.NewRequest(nil, "GET", "/ws") app.NewRequest(nil, "GET", "/write-string") app.NewRequest(nil, "GET", "/write-file") - app.NewRequest(nil, "GET", "/render", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeApplicationJSON)) + app.NewRequest(nil, "GET", "/render", http.Header{eudore.HeaderAccept: {eudore.MimeApplicationJSON}}) app.NewRequest(nil, "GET", "/status") app.NewRequest(nil, "GET", "/response") app.NewRequest(nil, "GET", "https://localhost:8089/push") app.NewRequest(nil, "GET", "https://localhost:8089/ws") - app.Client = app.WithClient(eudore.NewClientQuery("debug", "1")) + app.Client = app.WithClient(url.Values{"debug": {"1"}}) app.NewRequest(nil, "GET", "/redirect") app.NewRequest(nil, "GET", "/push") app.NewRequest(nil, "GET", "/response") app.NewRequest(nil, "GET", "/write-json") app.NewRequest(nil, "GET", "/write-string") app.NewRequest(nil, "GET", "/write-file") - app.NewRequest(nil, "GET", "/render", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeApplicationJSON)) + app.NewRequest(nil, "GET", "/render", http.Header{eudore.HeaderAccept: {eudore.MimeApplicationJSON}}) app.CancelFunc() app.Run() @@ -310,7 +366,9 @@ func TestContextResponse(*testing.T) { func TestContextLogger(*testing.T) { app := eudore.NewApp() - app.SetValue(eudore.ContextKeyLogger, eudore.NewLoggerStd(map[string]interface{}{"FileLine": true})) + app.SetValue(eudore.ContextKeyLogger, eudore.NewLogger(&eudore.LoggerConfig{ + Caller: true, + })) app.AddMiddleware("global", middleware.NewRequestIDFunc(func(eudore.Context) string { return uuid.New().String() })) @@ -348,15 +406,15 @@ func TestContextLogger(*testing.T) { ctx.Debug("err:", ctx.Err()) }) app.AnyFunc("/err2", func(ctx eudore.Context) { - ctx.Fatal(eudore.NewErrorStatusCode(fmt.Errorf("test error"), 432, 10032)) + ctx.Fatal(eudore.NewErrorWithStatusCode(fmt.Errorf("test error"), 432, 10032)) }) app.AnyFunc("/err3", func(ctx eudore.Context) { - eudore.NewErrorStatus(fmt.Errorf("test error"), 0) - ctx.Fatal(eudore.NewErrorStatus(fmt.Errorf("test error"), 432)) + eudore.NewErrorWithStatus(fmt.Errorf("test error"), 0) + ctx.Fatal(eudore.NewErrorWithStatus(fmt.Errorf("test error"), 432)) }) app.AnyFunc("/err4", func(ctx eudore.Context) { - eudore.NewErrorCode(fmt.Errorf("test error"), 0) - ctx.Fatal(eudore.NewErrorCode(fmt.Errorf("test error"), 10032)) + eudore.NewErrorWithCode(fmt.Errorf("test error"), 0) + ctx.Fatal(eudore.NewErrorWithCode(fmt.Errorf("test error"), 10032)) }) app.NewRequest(nil, "GET", "/ffile") @@ -391,8 +449,8 @@ func TestContextValue(*testing.T) { return ctx.Err() }) - app.NewRequest(nil, "GET", "/index", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeApplicationJSON)) - app.NewRequest(nil, "GET", "/index", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeTextHTML)) + app.NewRequest(nil, "GET", "/index", http.Header{eudore.HeaderAccept: {eudore.MimeApplicationJSON}}) + app.NewRequest(nil, "GET", "/index", http.Header{eudore.HeaderAccept: {eudore.MimeTextHTML}}) app.NewRequest(nil, "GET", "/cannel") app.CancelFunc() diff --git a/_example/controllerAutoRoute.go b/_example/controllerAutoRoute.go index e17c61b..961dd2d 100644 --- a/_example/controllerAutoRoute.go +++ b/_example/controllerAutoRoute.go @@ -32,21 +32,13 @@ ControllerRoute返回路径为'-'则忽略方法,第一个字符为' '表示 import ( "github.com/eudore/eudore" - "github.com/eudore/eudore/component/httptest" ) func main() { app := eudore.NewApp() app.AddController(new(autoController)) - client := httptest.NewClient(app) - client.NewRequest("GET", "/auto/index").Do().Out() - client.NewRequest("GET", "/auto/info/22").Do().Out() - client.NewRequest("POST", "/auto").Do().Out() - client.NewRequest("POST", "/auto/").Do().Out() - app.Listen(":8088") - // app.CancelFunc() app.Run() } @@ -70,11 +62,6 @@ func (*autoController) GetInfoById(ctx eudore.Context) interface{} { return ctx.GetParam("id") } -// String 方法返回控制器名称,响应Router.AddController输出的名称。 -func (*autoController) String() string { - return "hello.autoController" -} - // Help 方法定义一个控制器本身的方法。 func (*autoController) Help(ctx eudore.Context) {} diff --git a/_example/controller_test.go b/_example/controller_test.go index fae9c7f..90bcf7b 100644 --- a/_example/controller_test.go +++ b/_example/controller_test.go @@ -159,3 +159,20 @@ func (ctl *tableController) Hello() interface{} { func (ctl *tableController) Any(ctx eudore.Context) { ctx.Debug("tableController Any", ctl.Hello()) } + +type typeController[T any] struct { + eudore.ControllerAutoRoute +} + +func TestControllerTypeName(t *testing.T) { + app := eudore.NewApp() + app.AddController(&typeController[int]{}) + + app.CancelFunc() + app.Run() +} + +func (ctl *typeController[T]) Any(ctx eudore.Context) { + var t T + ctx.Debugf("typeController is %T", t) +} diff --git a/_example/converter_test.go b/_example/converter_test.go deleted file mode 100644 index 7589891..0000000 --- a/_example/converter_test.go +++ /dev/null @@ -1,359 +0,0 @@ -package eudore_test - -import ( - "context" - "encoding/json" - "fmt" - "reflect" - "testing" - "time" - "unsafe" - - "github.com/eudore/eudore" - "github.com/kr/pretty" -) - -type ( - config017 struct { - InterfaceType context.Context - Ptr *config017 - InterfaceNone interface{} - MapString map[string]*Param017 - MapInt map[int]string - Struct Param017 - Array [4]string - Slice []string - Chan chan int `json:"-"` - } - Param017 struct { - Key string - Value string - none string - } - config018 struct { - InterfaceType context.Context - Ptr *config018 - InterfaceNone interface{} - MapString map[string]*Param018 - MapInt map[int]string - Struct Param018 - Array [4]string - Slice []string - Chan chan int `json:"-"` - } - Param018 struct { - Key string - none string - } -) - -func TestConvertValueSet(t *testing.T) { - config := &config017{} - // setValue - eudore.Set(config, "Chan.1", "value") - // setInterface - eudore.Set(config, "InterfaceType.Key", "value") - eudore.Set(config, "InterfaceNone.Key", "value") - // setStruct - eudore.Set(config, "Struct.Key", "value") - eudore.Set(config, "Struct.none", "value") - eudore.Set(config, "Struct.no", "value") - // setMap - eudore.Set(config, "MapString.Key.Key", "value1") - eudore.Set(config, "MapString.Key.Key", "value2") - // setArray - eudore.Set(config, "Array.0", "000") - eudore.Set(config, "Array.4", "4") - // setSlice - eudore.Set(config, "Slice.0", "000") - eudore.Set(config, "Slice.4", "4") - eudore.Set(config, "Slice.x", "x") - // error - eudore.Set(nil, "Slice.x", "x") - eudore.Set(TestConvertValueSet, "Slice.x", "x") - - body, err := json.Marshal(config) - t.Log(string(body), err) - if string(body) != `{"InterfaceType":null,"Ptr":null,"InterfaceNone":{"Key":"value"},"MapString":{"Key":{"Key":"value2","Value":""}},"MapInt":null,"Struct":{"Key":"value","Value":""},"Array":["000","","",""],"Slice":["000","","","","4","x"]}` { - panic("check result") - } -} - -func TestConvertValueGet(t *testing.T) { - config := &config017{} - // getValue - eudore.Get(config, "InterfaceNone.Key") - eudore.Get(config, "Chan.1") - // getStruct - config.Struct = Param017{ - Key: "value", - none: "value", - } - eudore.Get(config, "Struct.Key") - eudore.Get(config, "Struct.none") - eudore.Get(config, "Struct.no") - // getMap - eudore.Get(config, "MapString.Key") - config.MapString = map[string]*Param017{ - "Key": {Key: "String"}, - } - eudore.Get(config, "MapString.Key") - eudore.Get(config, "MapString.Key2") - config.MapInt = map[int]string{ - 1: "int", - } - eudore.Get(config, "MapInt.Key") - // getSlice - eudore.Get(config, "Slice.x") - config.Slice = []string{"000", "", "", "", "4"} - eudore.Get(config, "Slice.0") - eudore.Get(config, "Slice.4") - eudore.Get(config, "Slice.x") - - // error - eudore.Get(nil, "Slice.x") - eudore.GetWithTags(config, "Slice.0", []string{"alias"}, false) -} - -func TestConvertMappingMapString(t *testing.T) { - config := &config017{} - config.InterfaceNone = config - config.MapString = map[string]*Param017{ - "Key": {Key: "String"}, - "nil": nil, - } - eudore.ConvertMapString(config) - eudore.ConvertMap(config) - - var data map[string]interface{} - eudore.ConvertTo(config, &data) - - config2 := &config017{} - eudore.ConvertTo(&data, config2) -} - -func TestConvertMappingTo(t *testing.T) { - config := &config017{} - config.Ptr = config - config.InterfaceNone = config - config.MapString = map[string]*Param017{ - "Key": {Key: "String"}, - "nil": nil, - } - config.MapInt = map[int]string{ - 1: "A", - 2: "B", - } - config.Struct = Param017{Key: "name", Value: "018"} - config.Array = [4]string{"str1", "str2", "str3", ""} - config.Slice = []string{"Slice1", "Slice2", "Slice3", ""} - config.InterfaceType = context.WithValue(context.Background(), "key", "name") - - config2 := &config017{} - conv(config, config2) - - var config3 map[string]interface{} - conv(config, &config3) - - var config4 map[string]interface{} - conv(&config3, &config4) - - var config5 map[string]interface{} - config5 = make(map[string]interface{}) - conv(config3, &config5) - - config6 := &config017{} - config3["none"] = "none" - conv(&config3, config6) - - config7 := &config018{} - config7.InterfaceType = context.Background() - conv(config, config7) - - eudore.ConvertTo(config, config4) - eudore.ConvertTo(config, nil) - eudore.ConvertTo(nil, config4) -} - -func conv(a, b interface{}) { - eudore.ConvertTo(a, b) - // 1.13 not json - // return - fmt.Printf("%# v\n", a) - body, err := json.Marshal(a) - fmt.Printf("%s %v\n", body, err) - - fmt.Printf("%# v\n", b) - body, err = json.Marshal(b) - fmt.Printf("%s %v\n", body, err) - fmt.Println("------------------") -} - -type BB struct { - BB1 int - BB2 string -} - -type withString struct { - String string `alias:"string"` - Bytes []byte `alias:"bytes"` - Int int `alias:"int"` - Int8 int8 `alias:"int8"` - Int16 int16 `alias:"int16"` - Int32 int32 `alias:"int32"` - Int64 int64 `alias:"int64"` - Uint uint `alias:"uint"` - Uint8 uint8 `alias:"uint8"` - Uint16 uint16 `alias:"uint16"` - Uint32 uint32 `alias:"uint32"` - Uint64 uint64 `alias:"uint64"` - Bool bool `alias:"bool"` - Float32 float32 `alias:"float32"` - Float64 float64 `alias:"float64"` - Complex64 complex64 `alias:"complex64"` - Complex128 complex128 `alias:"complex128"` - Time time.Time `alias:"time"` - Time2 Time020 `alias:"time2"` - Duration time.Duration `alias:"duration"` - Interface interface{} `alias:"interface"` - Struct BB `alias:"struct"` - StructNil BB `alias:"structnil"` - Map map[interface{}]interface{} `alias:"map"` - MapNil map[interface{}]interface{} `alias:"mapnil"` - MapString map[string]interface{} `alias:"mapstring"` - MapPtr map[*int]interface{} `alias:"mapptr"` - SliceInt []int `alias:"sliceint"` - SliceNil []int `alias:"slicenil"` - ArrayInt [10]int `alias:"arrayint"` - ArrayInt2 [10]int `alias:"arrayint2"` - Ptr *int `alias:"ptr"` - PtrMap *map[string]interface{} `alias:"ptrmap"` - Unsafe unsafe.Pointer `alias:"unsafe"` -} - -type Time020 time.Time - -func TestSetWithString2(t *testing.T) { - data := &withString{ - Unsafe: unsafe.Pointer(t), - } - t.Log(eudore.Set(data, "string", TestSetWithString2)) - t.Log(eudore.Set(data, "string", []byte("666s"))) - t.Log(eudore.Set(data, "bytes", "[]byte")) - t.Log(eudore.Set(data, "int", "")) - t.Log(eudore.Set(data, "int", []int{1, 2, 3})) - t.Log(eudore.Set(data, "int", []byte("123"))) - t.Log(eudore.Set(data, "int", "1")) - t.Log(eudore.Set(data, "int8", "2")) - t.Log(eudore.Set(data, "int16", "3")) - t.Log(eudore.Set(data, "int32", "4")) - t.Log(eudore.Set(data, "int64", "5")) - t.Log(eudore.Set(data, "uint", "")) - t.Log(eudore.Set(data, "uint", "1")) - t.Log(eudore.Set(data, "uint8", "2")) - t.Log(eudore.Set(data, "uint16", "3")) - t.Log(eudore.Set(data, "uint32", "4")) - t.Log(eudore.Set(data, "uint64", "5")) - t.Log(eudore.Set(data, "uint64", 6)) - t.Log(eudore.Set(data, "bool", "")) - t.Log(eudore.Set(data, "bool", "true")) - t.Log(eudore.Set(data, "float32", "")) - t.Log(eudore.Set(data, "float32", "16")) - t.Log(eudore.Set(data, "float64", "32")) - t.Log(eudore.Set(data, "complex64", "1+b")) - t.Log(eudore.Set(data, "complex64", "a+b")) - t.Log(eudore.Set(data, "complex64", "(1.2")) - t.Log(eudore.Set(data, "complex128", "1.2+3.4i")) - t.Log(eudore.Set(data, "duration", "")) - t.Log(eudore.Set(data, "duration", "3m")) - t.Log(eudore.Set(data, "duration", "30s")) - t.Log(eudore.Set(data, "duration", "30sxx")) - t.Log(eudore.Set(data, "time", "")) - t.Log(eudore.Set(data, "time", "2018-08-12")) - t.Log(eudore.Set(data, "time2", "2018-08-12")) - t.Log(eudore.Set(data, "interface", "eface")) - t.Log(eudore.Set(data, "struct", `{"BB1":22,"BB2":"22"}`)) - t.Log(eudore.Set(data, "map.eface", "str")) - t.Log(eudore.Set(data, "mapptr.2", "str")) - t.Log(eudore.Set(data, "mapstring", `{"a":1,"b":2}`)) - t.Log(eudore.Set(data, "sliceint", "[1]")) - t.Log(eudore.Set(data, "sliceint.2", "12")) - t.Log(eudore.Set(data, "arrayint.2", "12")) - t.Log(eudore.Set(data, "arrayint.12", "12")) - t.Log(eudore.Set(data, "ptr", "1234")) - t.Log(eudore.Set(data, "interface", data)) - t.Log(eudore.Set(data, "unsafe", "666")) - t.Log(eudore.Set(data, "unsafe", 666)) - t.Log(eudore.Set(data, "unsafe", unsafe.Pointer(t))) - t.Logf("%p\n", t) - t.Logf("struct: %# v\n", pretty.Formatter(data)) - - t.Log(eudore.GetWithTags(data, "mapnil.a", eudore.DefaultConvertTags, false)) - t.Log(eudore.GetWithTags(data, "mapptr.a", eudore.DefaultConvertTags, false)) - t.Log(eudore.GetWithTags(data, "slicenil.a", eudore.DefaultConvertTags, false)) - t.Log(eudore.GetWithTags(data, "sliceint.+", eudore.DefaultConvertTags, false)) - - var dm map[string]interface{} - eudore.ConvertTo(data, &dm) - t.Logf("map[string]interface{}: %# v\n", pretty.Formatter(dm)) - dm["nnn"] = 666 - eudore.ConvertTo(dm, data) - eudore.ConvertMap(data) - eudore.ConvertMapString(data) - eudore.ConvertMap(66) - eudore.ConvertMapString(66) - - var ii interface{} - t.Log(reflect.Indirect(reflect.ValueOf(ii)).Kind()) -} - -func Benchmark_iterator(b *testing.B) { - b.StopTimer() //调用该函数停止压力测试的时间计数 - - //做一些初始化的工作,例如读取文件数据,数据库连接之类的, - //这样这些时间不影响我们测试函数本身的性能 - - b.StartTimer() //重新开始时间 - b.ReportAllocs() - type Stu1 struct { - Name string - Age int - HIgh bool - sex string - } - type Stu2 struct { - Name string - Age int - HIgh bool - sex string - Stu1 Stu1 - } - type Stu3 struct { - Name string - Age int - HIgh bool - sex string - Stu2 Stu2 - } - stu := Stu3{ - Name: "张三3", - Age: 183, - HIgh: true, - sex: "男3", - Stu2: Stu2{ - Name: "张三2", - Age: 182, - HIgh: true, - sex: "男2", - Stu1: Stu1{ - Name: "张三1", - Age: 181, - HIgh: true, - sex: "男1", - }, - }, - } - for i := 0; i < b.N; i++ { - eudore.ConvertMap(stu) - } -} diff --git a/_example/database_test.go b/_example/database_test.go new file mode 100644 index 0000000..b1d80a2 --- /dev/null +++ b/_example/database_test.go @@ -0,0 +1,69 @@ +package eudore_test + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/eudore/eudore" +) + +func TestDatabase(*testing.T) { + app := eudore.NewApp() + app.SetValue(eudore.ContextKeyDatabase, &databaseTest{}) + app.SetValue(eudore.ContextKeyDatabaseRuntime, eudore.NewDatabaseRuntime) + app.SetValue(eudore.ContextKeyContextPool, eudore.NewContextBasePool(app)) + + app.GetFunc("/hello", func(ctx eudore.Context) { + ctx.Query(nil, nil) + ctx.Exec(nil) + + ctx.SetValue(eudore.ContextKeyDatabase, &databaseTest{Err: fmt.Errorf("test error")}) + ctx.Query(nil, nil) + ctx.SetHeader(eudore.HeaderXTraceID, "id") + ctx.Exec(nil) + }) + + app.NewRequest(nil, "GET", "/hello", + strings.NewReader("trace"), + eudore.NewClientCheckStatus(200), + ) + + app.CancelFunc() + app.Run() +} + +type databaseTest struct { + eudore.Database + Err error +} + +func (db *databaseTest) Query(ctx context.Context, data interface{}, stmt eudore.DatabaseStmt) error { + builder := &databaseBuilder{} + stmt.Build(builder) + return db.Err +} + +func (db *databaseTest) Exec(ctx context.Context, stmt eudore.DatabaseStmt) error { + builder := &databaseBuilder{} + stmt.Build(builder) + return db.Err +} + +type databaseBuilder struct { +} + +func (builder *databaseBuilder) Context() context.Context { + return nil +} +func (builder *databaseBuilder) DriverName() string { + return "" +} +func (builder *databaseBuilder) Metadata(interface{}) interface{} { + return nil +} +func (builder *databaseBuilder) WriteStmts(...interface{}) {} +func (builder *databaseBuilder) Result() (string, []interface{}, error) { + return "", nil, nil +} diff --git a/_example/funccreator_test.go b/_example/funccreator_test.go new file mode 100644 index 0000000..8166f46 --- /dev/null +++ b/_example/funccreator_test.go @@ -0,0 +1,189 @@ +package eudore_test + +import ( + "context" + "testing" + "time" + + "github.com/eudore/eudore" +) + +func TestFuncCreator(t *testing.T) { + app := eudore.NewApp() + app.SetValue(eudore.ContextKeyFuncCreator, eudore.NewFuncCreator()) + + fc := eudore.NewFuncCreatorWithContext(app) + t.Log(fc.RegisterFunc("zero", func() {})) + t.Log(fc.CreateFunc(eudore.FuncCreateString, "zero1")) + t.Log(fc.CreateFunc(eudore.FuncCreateString, "zero=1")) + t.Log(fc.CreateFunc(eudore.FuncCreateKind(0), "zero")) + + fc.(interface{ Metadata() any }).Metadata() +} + +func mustCreate(i any, _ error) any { + return i +} + +func TestFuncCreatorRun(t *testing.T) { + fc := eudore.NewFuncCreatorWithContext(context.Background()) + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "zero")).(func(string) bool)("zero") + mustCreate(fc.CreateFunc(eudore.FuncCreateAny, "zero")).(func(any) bool)("zero") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "nozero")).(func(string) bool)("zero") + mustCreate(fc.CreateFunc(eudore.FuncCreateAny, "nozero")).(func(any) bool)("zero") + mustCreate(fc.CreateFunc(eudore.FuncCreateAny, "nozero")).(func(any) bool)(time.Time{}) + + mustCreate(fc.CreateFunc(eudore.FuncCreateInt, "min=0")).(func(int) bool)(0) + mustCreate(fc.CreateFunc(eudore.FuncCreateInt, "max=0")).(func(int) bool)(0) + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "min=0")).(func(string) bool)("0") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "min=0")).(func(string) bool)("x0") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "max=0")).(func(string) bool)("0") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "max=0")).(func(string) bool)("x0") + fc.CreateFunc(eudore.FuncCreateInt, "min=0") + fc.CreateFunc(eudore.FuncCreateUint, "min=0") + fc.CreateFunc(eudore.FuncCreateFloat, "min=0") + fc.CreateFunc(eudore.FuncCreateBool, "min=0") + fc.CreateFunc(eudore.FuncCreateInt, "min=x0") + fc.CreateFunc(eudore.FuncCreateInt, "max=x0") + fc.CreateFunc(eudore.FuncCreateString, "min=x0") + fc.CreateFunc(eudore.FuncCreateString, "max=x0") + + mustCreate(fc.CreateFunc(eudore.FuncCreateInt, "equal=0")).(func(int) bool)(0) + mustCreate(fc.CreateFunc(eudore.FuncCreateInt, "enum=1,2,3")).(func(int) bool)(0) + mustCreate(fc.CreateFunc(eudore.FuncCreateInt, "enum=1,2,3")).(func(int) bool)(1) + mustCreate(fc.CreateFunc(eudore.FuncCreateInt, "enum=1,2,3,4,5,6,7,8,9")).(func(int) bool)(0) + fc.CreateFunc(eudore.FuncCreateInt, "equal=x0") + fc.CreateFunc(eudore.FuncCreateInt, "enum=x0") + + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "len=0")).(func(string) bool)("0") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "len!=0")).(func(string) bool)("0") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "len>0")).(func(string) bool)("0") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "len<0")).(func(string) bool)("0") + mustCreate(fc.CreateFunc(eudore.FuncCreateAny, "len=0")).(func(any) bool)("0") + mustCreate(fc.CreateFunc(eudore.FuncCreateAny, "len>0")).(func(any) bool)("0") + mustCreate(fc.CreateFunc(eudore.FuncCreateAny, "len<0")).(func(any) bool)("0") + mustCreate(fc.CreateFunc(eudore.FuncCreateAny, "len=0")).(func(any) bool)(true) + mustCreate(fc.CreateFunc(eudore.FuncCreateAny, "len>0")).(func(any) bool)(true) + mustCreate(fc.CreateFunc(eudore.FuncCreateAny, "len<0")).(func(any) bool)(true) + mustCreate(fc.CreateFunc(eudore.FuncCreateAny, "after:20180801")).(func(any) bool)(time.Now()) + mustCreate(fc.CreateFunc(eudore.FuncCreateAny, "after:20180801")).(func(any) bool)(nil) + mustCreate(fc.CreateFunc(eudore.FuncCreateAny, "before:20180801")).(func(any) bool)(time.Now()) + mustCreate(fc.CreateFunc(eudore.FuncCreateAny, "before:20180801")).(func(any) bool)(nil) + fc.CreateFunc(eudore.FuncCreateString, "len=x0") + fc.CreateFunc(eudore.FuncCreateAny, "len=x0") + fc.CreateFunc(eudore.FuncCreateAny, "after:") + fc.CreateFunc(eudore.FuncCreateAny, "before:") + + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "num")).(func(string) bool)("0") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "integer")).(func(string) bool)("0") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "integer")).(func(string) bool)("x0") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "domain")).(func(string) bool)("www.eudore.cn") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "mail")).(func(string) bool)("postmaster@eudore.cn") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "mail")).(func(string) bool)("eudore.cn") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "phone")).(func(string) bool)("15824681234") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "phone")).(func(string) bool)("+86 15824681234") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "phone")).(func(string) bool)("010-32221234") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "phone")).(func(string) bool)("xx010-32221234") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "regexp=^\\d+$")).(func(string) bool)("123456") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "patten=123456")).(func(string) bool)("123456") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "patten=*")).(func(string) bool)("123456") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "patten!=a*")).(func(string) bool)("123456") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "patten!=a*b")).(func(string) bool)("a") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "patten=a*b")).(func(string) bool)("axb") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "prefix=123")).(func(string) bool)("123456") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "count=1,a")).(func(string) bool)("1aa23456") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "count>1,a")).(func(string) bool)("1aa23456") + mustCreate(fc.CreateFunc(eudore.FuncCreateString, "count<1,a")).(func(string) bool)("1aa23456") + fc.CreateFunc(eudore.FuncCreateString, "regexp=^[($") + fc.CreateFunc(eudore.FuncCreateString, "count=") + fc.CreateFunc(eudore.FuncCreateString, "count=x,x") + + mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "default")).(func(string) string)("zero") + mustCreate(fc.CreateFunc(eudore.FuncCreateSetInt, "default")).(func(int) int)(0) + mustCreate(fc.CreateFunc(eudore.FuncCreateSetUint, "default")).(func(uint) uint)(0) + mustCreate(fc.CreateFunc(eudore.FuncCreateSetFloat, "default")).(func(float64) float64)(0) + mustCreate(fc.CreateFunc(eudore.FuncCreateSetAny, "default")).(func(any) any)("zero") + mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "value=str")).(func(string) string)("zero") + mustCreate(fc.CreateFunc(eudore.FuncCreateSetAny, "value=")) + mustCreate(fc.CreateFunc(eudore.FuncCreateSetAny, "value=20060102")).(func(any) any)("zero") + mustCreate(fc.CreateFunc(eudore.FuncCreateSetAny, "value=20060102")).(func(any) any)(time.Now()) + mustCreate(fc.CreateFunc(eudore.FuncCreateSetInt, "add=10")).(func(int) int)(0) + mustCreate(fc.CreateFunc(eudore.FuncCreateSetInt, "add=x")) + mustCreate(fc.CreateFunc(eudore.FuncCreateSetAny, "add=")) + mustCreate(fc.CreateFunc(eudore.FuncCreateSetAny, "add=10h")).(func(any) any)("") + mustCreate(fc.CreateFunc(eudore.FuncCreateSetAny, "add=10h")).(func(any) any)(time.Now()) + mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "now:20060102")).(func(string) string)("now") + mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "now:xx")).(func(string) string)("now") + mustCreate(fc.CreateFunc(eudore.FuncCreateSetAny, "now")).(func(any) any)(time.Time{}) + mustCreate(fc.CreateFunc(eudore.FuncCreateSetAny, "now")).(func(any) any)("now") + mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "replace=10,AA,aa")).(func(string) string)("AAAAA") + mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "trim=1")).(func(string) string)("1234 ") + + t.Log(mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "hidesurname")).(func(string) string)("A4")) + t.Log(mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "hidesurname")).(func(string) string)("eudore")) + t.Log(mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "hidesurname")).(func(string) string)("eudore org")) + t.Log(mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "hidesurname")).(func(string) string)("世界")) + t.Log(mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "hidename")).(func(string) string)("A4")) + t.Log(mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "hidename")).(func(string) string)("eudore")) + t.Log(mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "hidename")).(func(string) string)("eudore org")) + t.Log(mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "hidename")).(func(string) string)("世界")) + t.Log(mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "hidemail")).(func(string) string)("postmaster@eudore.cn")) + t.Log(mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "hidemail")).(func(string) string)("master@eudore.cn")) + t.Log(mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "hidemail")).(func(string) string)("root@eudore.cn")) + t.Log(mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "hidemail")).(func(string) string)("eudore.cn")) + t.Log(mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "hidephone")).(func(string) string)("15824681234")) + t.Log(mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "hidephone")).(func(string) string)("+86 15824681234")) + t.Log(mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "hidephone")).(func(string) string)("010-32221234")) + t.Log(mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "hidephone")).(func(string) string)("xx010-32221234")) + t.Log(mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "hide")).(func(string) string)("pass")) + t.Log(mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "len")).(func(string) string)("123456")) + t.Log(mustCreate(fc.CreateFunc(eudore.FuncCreateSetString, "md5")).(func(string) string)("123456")) + fc.CreateFunc(eudore.FuncCreateSetInt, "value=str") +} + +func TestFuncCreatorExpr(t *testing.T) { + k := eudore.FuncCreateString + fc := eudore.NewFuncCreatorExpr() + fc.RegisterFunc("init") + + t.Log(mustCreate(fc.CreateFunc(k, "NOT zero")).(func(string) bool)("0")) + t.Log(mustCreate(fc.CreateFunc(k, "NOT(zero)")).(func(string) bool)("")) + t.Log(mustCreate(fc.CreateFunc(k, "len>7 AND domain")).(func(string) bool)("eudore.cn")) + t.Log(mustCreate(fc.CreateFunc(k, "len>7 \r\nAND(contains=xx ss)")).(func(string) bool)("xx")) + t.Log(mustCreate(fc.CreateFunc(k, "(len>7 AND contains=xxss) OR mail OR phone")).(func(string) bool)("xxssxxss")) + t.Log(mustCreate(fc.CreateFunc(k, "len>7 AND(contains=xxss OR mail) OR phone")).(func(string) bool)("xxssxxs")) + t.Log(mustCreate(fc.CreateFunc(k, "len>7 AND contains=xxss OR(mail OR phone)")).(func(string) bool)("xxss@eudore.cn")) + + for i := 0; i < 8; i++ { + fc.CreateFunc(eudore.FuncCreateKind(i), "NOT zero1") + } + fc.CreateFunc(k, "mail AND zero1") + fc.CreateFunc(k, "mail OR zero1") + fc.CreateFunc(k, " ") + fc.CreateFunc(k, "( )") + fc.CreateFunc(k, "(zero sss") + fc.CreateFunc(k, "(zero sss) AND") + fc.CreateFunc(k, "(zero sss) AND ( )") + fc.CreateFunc(k, " zero") + fc.CreateFunc(k, "(zero)") + fc.CreateFunc(k, "NOT\r\nzero") + fc.CreateFunc(k, "NOT(zero)") + fc.CreateFunc(k, " len>7 AND NOT(contains=xx ss)") + fc.CreateFunc(k, " len>7 AND(contains=xx ss) AND mail") + fc.CreateFunc(k, "(len>7 AND contains=xxss) OR mail OR phone") + fc.CreateFunc(k, " len>7 AND contains=xxss OR(mail OR phone)") + fc.CreateFunc(k, " len>7 AND(contains=xxss OR mail) OR phone") + + fc.List() + + // print meta + meta, ok := fc.(interface{ Metadata() any }).Metadata().(eudore.MetadataFuncCreator) + if ok { + for _, expr := range meta.Exprs { + t.Log("expr:", expr) + } + for _, err := range meta.Errors { + t.Log("err:", err) + } + } +} diff --git a/_example/handler_test.go b/_example/handler_test.go index 1b8a14f..fd9143d 100644 --- a/_example/handler_test.go +++ b/_example/handler_test.go @@ -1,15 +1,103 @@ package eudore_test import ( + "embed" "errors" "fmt" + "io/fs" "net/http" + "net/url" + "os" "strings" "testing" "github.com/eudore/eudore" + "github.com/eudore/eudore/middleware" ) +//go:embed *.go +var root embed.FS + +type fsPermission struct{} + +func (fsPermission) Open(name string) (http.File, error) { + return nil, os.ErrPermission +} + +type fsHTTPDir struct{} + +func (fsHTTPDir) Open(name string) (http.File, error) { + return fsHTTPFile{}, nil +} + +type fsHTTPFile struct { + http.File +} + +func (fsHTTPFile) Readdir(count int) ([]fs.FileInfo, error) { + return nil, fmt.Errorf("test error, not dir") +} + +func (fsHTTPFile) Stat() (fs.FileInfo, error) { + return os.Stat(".") +} + +func (fsHTTPFile) Close() error { + return nil +} + +func TestHandlerRoute(t *testing.T) { + os.Mkdir("static/", 0o644) + defer os.RemoveAll("static/") + os.WriteFile("static/403.js", []byte("1234567890abcdef"), 0o000) + for i := 0; i < 8; i++ { + os.WriteFile("static/index.js", []byte("1234567890abcdef"), 0o644) + } + + app := eudore.NewApp() + app.SetValue(eudore.ContextKeyHandlerExtender, eudore.NewHandlerExtender()) + app.AddMiddleware("global", middleware.NewLoggerFunc(app, "route")) + app.AddHandler("404", "", eudore.HandlerRouter404) + app.AddHandler("405", "", eudore.HandlerRouter405) + app.GetFunc("/403", eudore.HandlerRouter403) + app.GetFunc("/index", eudore.HandlerEmpty) + app.GetFunc("/meta/*", eudore.HandlerMetadata) + app.GetFunc("/static/dir/*", eudore.NewHandlerStatic(".", ".")) + app.GetFunc("/static/index/* autoindex=true", eudore.NewHandlerStatic(".", ".")) + app.GetFunc("/static/embed/*", root) + app.GetFunc("/static/fs1/* autoindex=true", fsPermission{}) + app.GetFunc("/static/fs2/* autoindex=true", fsHTTPDir{}) + + app.NewRequest(nil, "GET", "/index") + app.NewRequest(nil, "POST", "/index") + app.NewRequest(nil, "GET", "/403") + app.NewRequest(nil, "GET", "/404") + app.NewRequest(nil, "GET", "/meta/") + app.NewRequest(nil, "GET", "/meta/app") + app.NewRequest(nil, "GET", "/meta/router") + app.NewRequest(nil, "GET", "/static/dir/app_test.go") + app.NewRequest(nil, "GET", "/static/embed/") + app.NewRequest(nil, "GET", "/static/embed/app.go") + app.NewRequest(nil, "GET", "/static/embed/app_test.go") + app.NewRequest(nil, "GET", "/static/index/") + app.NewRequest(nil, "GET", "/static/index/403.js") + app.NewRequest(nil, "GET", "/static/fs1/") + app.NewRequest(nil, "GET", "/static/fs2/") + + eudore.NewFileSystems(".", http.Dir("."), eudore.NewFileSystems(".", ".")) + + app.SetValue(eudore.ContextKeyHandlerExtender, eudore.NewHandlerExtenderTree()) + app.NewRequest(nil, "GET", "/meta/") + app.SetValue(eudore.ContextKeyHandlerExtender, eudore.NewHandlerExtenderWarp( + eudore.NewHandlerExtender(), + eudore.NewHandlerExtenderTree(), + )) + app.NewRequest(nil, "GET", "/meta/") + + app.CancelFunc() + app.Run() +} + func BindTestErr(ctx eudore.Context, i interface{}) error { if ctx.GetQuery("binderr") != "" { return errors.New("test bind error") @@ -24,10 +112,12 @@ func RenderTestErr(ctx eudore.Context, i interface{}) error { return eudore.RenderJSON(ctx, i) } -type handlerHttp1 struct{} -type handlerHttp2 struct{} -type handlerHttp3 struct{} -type handlerControler4 struct{ eudore.ControllerAutoRoute } +type ( + handlerHttp1 struct{} + handlerHttp2 struct{} + handlerHttp3 struct{} + handlerControler4 struct{ eudore.ControllerAutoRoute } +) func (handlerHttp1) HandleHTTP(eudore.Context) {} func (h handlerHttp2) CloneHandler() http.Handler { return h } @@ -118,10 +208,10 @@ func TestHandlerReister(t *testing.T) { } for i := 1; i < 14; i++ { - app.NewRequest(nil, "GET", fmt.Sprintf("/2/%d", i), eudore.NewClientQuery("binderr", "1")) + app.NewRequest(nil, "GET", fmt.Sprintf("/2/%d", i), url.Values{"binderr": {"1"}}) } for i := 1; i < 14; i++ { - app.NewRequest(nil, "GET", fmt.Sprintf("/2/%d", i), eudore.NewClientQuery("rendererr", "1")) + app.NewRequest(nil, "GET", fmt.Sprintf("/2/%d", i), url.Values{"rendererr": {"1"}}) } app.CancelFunc() @@ -141,7 +231,7 @@ func TestHandlerList(t *testing.T) { return nil }) api.AnyFunc("/user/info", "hello") - t.Log(strings.Join(api.(eudore.HandlerExtender).ListExtendHandlerNames(), "\n")) + t.Log(strings.Join(api.(eudore.HandlerExtender).List(), "\n")) app.CancelFunc() app.Run() @@ -170,12 +260,12 @@ func TestHandlerRPC(t *testing.T) { }) app.NewRequest(nil, "PUT", "/1/1") - app.NewRequest(nil, "PUT", "/1/2", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeApplicationJSON)) + app.NewRequest(nil, "PUT", "/1/2", http.Header{eudore.HeaderAccept: {eudore.MimeApplicationJSON}}) app.NewRequest(nil, "PUT", "/1/2", eudore.NewClientBodyJSON(map[string]interface{}{ "name": "eudore", })) - app.NewRequest(nil, "GET", "/1/1", eudore.NewClientQuery("binderr", "1")) - app.NewRequest(nil, "GET", "/1/1", eudore.NewClientQuery("rendererr", "1")) + app.NewRequest(nil, "GET", "/1/1", url.Values{"binderr": {"1"}}) + app.NewRequest(nil, "GET", "/1/1", url.Values{"rendererr": {"1"}}) app.CancelFunc() app.Run() @@ -195,12 +285,3 @@ func TestHandlerFunc(t *testing.T) { } t.Log(len(hs)) } - -func TestHandlerStatic(t *testing.T) { - app := eudore.NewApp() - app.AnyFunc("/static/*", eudore.NewStaticHandler("", "")) - - app.NewRequest(nil, "GET", "/static/index.html") - app.CancelFunc() - app.Run() -} diff --git a/_example/handlerdata_test.go b/_example/handlerdata_test.go index f8720bf..6e60275 100644 --- a/_example/handlerdata_test.go +++ b/_example/handlerdata_test.go @@ -4,7 +4,8 @@ import ( "context" "fmt" "html/template" - "reflect" + "net/http" + "net/url" "strings" "testing" @@ -34,16 +35,24 @@ func TestHandlerDataBind(*testing.T) { return &data, nil }) - app.NewRequest(nil, "GET", "/hello", eudore.NewClientBodyString("trace"), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "GET", "/data/header", eudore.NewClientHeader("name", "eudore"), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "GET", "/data/get-url", eudore.NewClientQuery("name", "eudore"), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "POST", "/data/post-url", eudore.NewClientQuery("name", "eudore"), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "POST", "/data/post-mime", eudore.NewClientQuery("name", "eudore"), eudore.NewClientHeader(eudore.HeaderContentType, "pb"), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "PATCH", "/data/patch-mime", eudore.NewClientQuery("name", "eudore"), eudore.NewClientHeader(eudore.HeaderContentType, "pb"), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "DELETE", "/data/detele-mime", eudore.NewClientQuery("name", "eudore"), eudore.NewClientHeader(eudore.HeaderContentType, "pb"), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "PUT", "/data/json", eudore.NewClientBodyJSONValue("name", "eudore"), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "PUT", "/data/json", eudore.NewClientHeader(eudore.HeaderContentType, eudore.MimeApplicationXML), eudore.NewClientBodyString("eudore"), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "PUT", "/data/form", eudore.NewClientBodyFormValue("name", "eudore"), eudore.NewClientCheckStatus(200)) + form := eudore.NewClientBodyForm(nil) + form.AddFile("file", "app", []byte("form body")) + + app.NewRequest(nil, "GET", "/hello", strings.NewReader("trace")) + app.NewRequest(nil, "GET", "/data/header", http.Header{"X-Name": {"eudore"}}) + app.NewRequest(nil, "GET", "/data/get-url", url.Values{"name": {"eudore"}}) + app.NewRequest(nil, "POST", "/data/post-url", url.Values{"name": {"eudore"}}) + app.NewRequest(nil, "POST", "/data/post-mime", url.Values{"name": {"eudore"}}, http.Header{eudore.HeaderContentType: {"pb"}}) + app.NewRequest(nil, "PATCH", "/data/patch-mime", url.Values{"name": {"eudore"}}, http.Header{eudore.HeaderContentType: {"pb"}}) + app.NewRequest(nil, "DELETE", "/data/detele-mime", url.Values{"name": {"eudore"}}, http.Header{eudore.HeaderContentType: {"pb"}}) + app.NewRequest(nil, "PUT", "/data/json", eudore.NewClientBodyJSON(url.Values{"name": {"eudore"}})) + app.NewRequest(nil, "PUT", "/data/xml", eudore.NewClientBodyXML(&Data{"eudore"})) + app.NewRequest(nil, "PUT", "/data/url", eudore.NewClientBodyForm(url.Values{"name": {"eudore"}})) + app.NewRequest(nil, "PUT", "/data/form", form) + app.NewRequest(nil, "PUT", "/data/protobuf", http.Header{eudore.HeaderContentType: {eudore.MimeApplicationProtobuf}}) + + app.Values = nil + eudore.NewContextBasePool(app) app.CancelFunc() app.Run() @@ -55,14 +64,20 @@ func TestHandlerDataRender(*testing.T) { } app := eudore.NewApp() - app.SetValue(eudore.ContextKeyFilte, func(ctx eudore.Context, i interface{}) error { + app.SetValue(eudore.ContextKeyFilter, func(ctx eudore.Context, i interface{}) error { if ctx.Path() == "/err" { return fmt.Errorf("filte error") } return nil }) app.SetValue(eudore.ContextKeyContextPool, eudore.NewContextBasePool(app)) - app.AnyFunc("/data/* template=data", func(ctx eudore.Context) interface{} { + app.AnyFunc("/data/*", func(ctx eudore.Context) interface{} { + return &Data{"eudore"} + }) + app.AnyFunc("/html/err", func(ctx eudore.Context) interface{} { + return &struct{ Name func() }{} + }) + app.AnyFunc("/html/* template=data", func(ctx eudore.Context) interface{} { return &Data{"eudore"} }) app.AnyFunc("/text/stringer", func(ctx eudore.Context) interface{} { @@ -72,213 +87,209 @@ func TestHandlerDataRender(*testing.T) { return "text/string" }) - app.NewRequest(nil, "GET", "/err", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeTextPlain)) - app.NewRequest(nil, "GET", "/data/text", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeTextPlain)) - app.NewRequest(nil, "GET", "/data/json", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeApplicationJSON)) - app.NewRequest(nil, "GET", "/data/xml", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeApplicationXML)) - app.NewRequest(nil, "GET", "/data/html", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeTextHTML)) - eudore.DefaultRenderHTMLTemplate = nil - app.NewRequest(nil, "GET", "/data/html", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeTextHTML)) + accept := func(val string) http.Header { + return http.Header{eudore.HeaderAccept: {val}} + } + + app.NewRequest(nil, "GET", "/err", accept(eudore.MimeTextPlain)) + app.NewRequest(nil, "GET", "/data/quality", accept(eudore.MimeTextPlain+";q=0")) + app.NewRequest(nil, "GET", "/data/text", accept(eudore.MimeTextPlain)) + app.NewRequest(nil, "GET", "/data/json", accept(eudore.MimeApplicationJSON)) + app.NewRequest(nil, "GET", "/data/xml", accept(eudore.MimeApplicationXML)) + app.NewRequest(nil, "GET", "/data/html", accept(eudore.MimeTextHTML)) + app.NewRequest(nil, "GET", "/data/protobuf", accept(eudore.MimeApplicationProtobuf)) + app.NewRequest(nil, "GET", "/html/err", accept(eudore.MimeTextHTML)) + app.NewRequest(nil, "GET", "/html/html", accept(eudore.MimeTextHTML)) app.NewRequest(nil, "GET", "/data/accept") - app.NewRequest(nil, "GET", "/text/stringer", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeTextPlain)) - app.NewRequest(nil, "GET", "/text/string", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeTextPlain)) - app.NewRequest(nil, "GET", "/text/string", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeApplicationJSON)) - app.NewRequest(nil, "GET", "/text/string", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeApplicationJSONCharsetUtf8)) + + app.NewRequest(nil, "GET", "/text/stringer", accept(eudore.MimeTextPlain)) + app.NewRequest(nil, "GET", "/text/string", accept(eudore.MimeTextPlain)) + app.NewRequest(nil, "GET", "/text/string", accept(eudore.MimeApplicationJSON)) + app.NewRequest(nil, "GET", "/text/string", accept(eudore.MimeApplicationJSONCharsetUtf8)) temp, _ := template.New("").Parse(`{{- define "data" -}} Data Name is {{.Name}} {{- end -}}`) app.SetValue(eudore.ContextKeyTemplate, temp) + app.NewRequest(nil, "GET", "/data/html", accept(eudore.MimeTextHTML)) + app.NewRequest(nil, "GET", "/text/string", accept(eudore.MimeTextHTML)) + app.NewRequest(nil, "GET", "/text/string", accept(eudore.MimeTextHTML)) - app.NewRequest(nil, "GET", "/data/html", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeTextHTML)) - app.NewRequest(nil, "GET", "/text/string", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeTextHTML)) - app.NewRequest(nil, "GET", "/text/string", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeTextHTML)) + app.SetValue(eudore.ContextKeyTemplate, nil) + app.NewRequest(nil, "GET", "/data/html", accept(eudore.MimeTextHTML)) app.CancelFunc() app.Run() } -func TestFuncCreator(*testing.T) { - type Data struct { - Name string `json:"name" xml:"name"` - } - - app := eudore.NewApp() - - fc := eudore.NewFuncCreator() - app.SetValue(eudore.ContextKeyFuncCreator, fc) - // register error - fc.Register("test", "not func", TestFuncCreator, func(string) func(string) { - return nil - }, func(string) func(string) bool { - return nil - }, func(string) func(string) string { - return nil - }) - - var fn interface{} - var err error - typeInt := reflect.TypeOf((*int)(nil)).Elem() - typeString := reflect.TypeOf((*string)(nil)).Elem() - typeInterface := reflect.TypeOf((*interface{})(nil)).Elem() - - fc.Create(typeInt, "nozero") - fc.Create(typeInt, "nozero:") - { - // validateIntNozero - fn, _ = fc.Create(typeInt, "nozero") - app.Info(fn.(func(int) bool)(0)) - } - { - // validateStringNozero - fn, _ = fc.Create(typeString, "nozero") - app.Info(fn.(func(string) bool)("123456")) - } - { - // validateInterfaceNozero - fn, _ = fc.Create(typeInterface, "nozero") - app.Info(fn.(func(interface{}) bool)("123456")) - app.Info(fn.(func(interface{}) bool)([]int{1, 2, 3, 4, 5, 6})) - } - { - // validateStringIsnum - fn, _ = fc.Create(typeString, "isnum") - app.Info(fn.(func(string) bool)("234")) - app.Info(fn.(func(string) bool)("xx2")) - } - { - // validateNewIntMin - fc.Create(typeInt, "min=xx") - fn, _ = fc.Create(typeInt, "min=033") - app.Info(fn.(func(int) bool)(12)) - app.Info(fn.(func(int) bool)(644)) - } - { - // validateNewIntMax - fc.Create(typeInt, "max") - fc.Create(typeInt, "max=xx") - fn, _ = fc.Create(typeInt, "max=033") - app.Info(fn.(func(int) bool)(12)) - app.Info(fn.(func(int) bool)(644)) - } - { - // validateNewStringMin - fc.Create(typeString, "min=xx") - fn, _ = fc.Create(typeString, "min=033") - app.Info(fn.(func(string) bool)("12")) - app.Info(fn.(func(string) bool)("644")) - app.Info(fn.(func(string) bool)("xx")) - } - { - // validateNewStringMax - fc.Create(typeString, "max=xx") - fn, _ = fc.Create(typeString, "max=033") - app.Info(fn.(func(string) bool)("12")) - app.Info(fn.(func(string) bool)("644")) - app.Info(fn.(func(string) bool)("xx")) - } - { - // validateNewStringLen - fc.Create(typeString, "len>x") - fn, _ = fc.Create(typeString, "len>5") - app.Info(fn.(func(string) bool)("8812988")) - app.Info(fn.(func(string) bool)("123")) - app.Info(fn.(func(string) bool)("123456")) - fn, _ = fc.Create(typeString, "len<5") - app.Info(fn.(func(string) bool)("8812988")) - app.Info(fn.(func(string) bool)("123")) - app.Info(fn.(func(string) bool)("123456")) - fn, _ = fc.Create(typeString, "len=5") - app.Info(fn.(func(string) bool)("8812988")) - app.Info(fn.(func(string) bool)("123")) - app.Info(fn.(func(string) bool)("123456")) - } - { - // validateNewInterfaceLen - fc.Create(typeInterface, "len=.") - fn, _ = fc.Create(typeInterface, "len>4") - app.Info(fn.(func(interface{}) bool)("123456")) - app.Info(fn.(func(interface{}) bool)([]int{1, 2, 3, 4})) - app.Info(fn.(func(interface{}) bool)(6)) - fn, _ = fc.Create(typeInterface, "len<4") - app.Info(fn.(func(interface{}) bool)("123456")) - app.Info(fn.(func(interface{}) bool)([]int{1, 2, 3, 4})) - app.Info(fn.(func(interface{}) bool)(6)) - fn, _ = fc.Create(typeInterface, "len=4") - app.Info(fn.(func(interface{}) bool)("123456")) - app.Info(fn.(func(interface{}) bool)([]int{1, 2, 3, 4})) - app.Info(fn.(func(interface{}) bool)(6)) - } - { - // validateNewStringRegexp - _, err = fc.Create(typeString, "regexp^[($") - app.Info(err) - fn, _ = fc.Create(typeString, "regexp^\\d+$") - app.Info(fn.(func(string) bool)("123456")) - } +type dataValidate01 struct { + ID *int `json:"id" xml:"id" validate:"nozero,omitempty"` + Child []int `json:"child" xml:"child" validate:"nozero,omitempty"` + Name string `json:"name" xml:"name" validate:"nozero,len>4"` + Level1 string `json:"level1" xml:"level1" validate:"-"` +} +type dataValidate02 struct { + ID *int `json:"id" xml:"id" validate:"nozero"` + Name string `json:"name" xml:"name" validate:"len>4"` + Level1 string `json:"level1" xml:"level1"` +} +type dataValidate03 struct { + ID int `json:"id" xml:"id" validate:"(nozero),,"` +} +type dataValidate04 struct { + ID int `json:"id" xml:"id" validate:"not"` +} +type dataValidate05 struct { + dataValidate03 + *dataValidate04 +} - app.CancelFunc() - app.Run() +func (dataValidate03) Validate(context.Context) error { + return fmt.Errorf("test error validate") } func TestHandlerDataValidateField(*testing.T) { - eudore.NewValidateField(context.Background()) - type DataValidate01 struct { - Name string `json:"name" xml:"name" validate:"len>4"` - Email string `json:"email" xml:"email" validate:"email"` - Phone string `json:"phone" xml:"phone" validate:"phone"` - Level1 string `json:"level1" xml:"level1" validate:""` - } - - type DataValidate02 struct { - Name string `json:"name" xml:"name" validate:"len>4"` - Email string `json:"email" xml:"email" validate:"email"` - Level1 string `json:"level1" xml:"level1" validate:"is"` - } - app := eudore.NewApp() - fc := eudore.NewFuncCreator() - app.SetValue(eudore.ContextKeyFuncCreator, fc) - app.SetValue(eudore.ContextKeyValidate, eudore.NewValidateField(app)) + app.SetValue(eudore.ContextKeyValidater, eudore.NewValidateField(app)) app.SetValue(eudore.ContextKeyContextPool, eudore.NewContextBasePool(app)) - fc.Register("email", func(email string) bool { - return strings.HasSuffix(email, "@eudore.cn") + app.AddMiddleware(middleware.NewLoggerFunc(app)) + app.AnyFunc("/data/struct1", func(ctx eudore.Context) { + var data dataValidate01 + ctx.Bind(&data) }) - fc.Register("phone", func(phone string) bool { - return len(phone) == 11 + app.AnyFunc("/data/slice1", func(ctx eudore.Context) { + var data []dataValidate01 + ctx.Bind(&data) }) - - app.AnyFunc("/data/1", func(ctx eudore.Context) { - var data DataValidate01 + app.AnyFunc("/data/ptr1", func(ctx eudore.Context) { + var data []*dataValidate01 ctx.Bind(&data) }) - app.AnyFunc("/data/2", func(ctx eudore.Context) { - var data []DataValidate01 + app.AnyFunc("/data/any1", func(ctx eudore.Context) { + var data []any = []any{new(dataValidate01)} ctx.Bind(&data) }) - app.AnyFunc("/data/3", func(ctx eudore.Context) { - var data DataValidate02 + app.AnyFunc("/data/struct2", func(ctx eudore.Context) { + var data dataValidate02 ctx.Bind(&data) }) - app.AnyFunc("/data/4", func(ctx eudore.Context) { - var data []DataValidate02 + app.AnyFunc("/data/struct3", func(ctx eudore.Context) { + var data dataValidate03 ctx.Bind(&data) }) - app.AnyFunc("/data/5", func(ctx eudore.Context) { - var data []*DataValidate02 + app.AnyFunc("/data/struct4", func(ctx eudore.Context) { + var data dataValidate04 ctx.Bind(&data) }) - app.AnyFunc("/data/7", func(ctx eudore.Context) { - var data map[string]interface{} + app.AnyFunc("/data/struct5", func(ctx eudore.Context) { + var data dataValidate05 ctx.Bind(&data) }) - app.NewRequest(nil, "POST", "/data/1", eudore.NewClientBodyJSON(&DataValidate01{Name: "eudore", Email: "postmaster@eudore.cn", Phone: "15512344321"})) - app.NewRequest(nil, "POST", "/data/2", eudore.NewClientBodyJSON([]DataValidate01{{Name: "eudore"}})) - app.NewRequest(nil, "POST", "/data/3", eudore.NewClientBodyJSON(&DataValidate02{Name: "eudore"})) - app.NewRequest(nil, "POST", "/data/4", eudore.NewClientBodyJSON([]*DataValidate02{{Name: "eudore"}})) - app.NewRequest(nil, "POST", "/data/5", eudore.NewClientBodyJSON([]*DataValidate02{{Name: "eudore"}, {Name: "eudore"}})) - app.NewRequest(nil, "POST", "/data/7", eudore.NewClientBodyJSON(&DataValidate02{Name: "eudore"})) + fn := func(name string, val any) { + app.NewRequest(nil, "POST", "/data/"+name, eudore.NewClientBodyJSON(val)) + } + id := 4 + fn("struct1", dataValidate01{}) + fn("struct1", dataValidate01{ID: &id}) + fn("slice1", []dataValidate01{{Name: "A1"}}) + fn("slice1", []dataValidate01{{Name: "eudore", Child: []int{0, 0, 0}}}) + fn("slice1", []dataValidate01{{Name: "eudore", Child: []int{1, 2, 3}}}) + fn("ptr1", []dataValidate01{{Name: "eudore"}}) + fn("any1", []dataValidate01{{Name: "eudore"}}) + fn("struct2", dataValidate02{}) + fn("struct5", dataValidate03{ID: 32}) + fn("struct4", dataValidate04{}) + fn("struct4", dataValidate04{}) + fn("struct3", dataValidate05{}) + + app.CancelFunc() + app.Run() +} + +func TestHandlerDataFilterRule(*testing.T) { + type LoggerConfig struct { + Stdout bool `json:"stdout" xml:"stdout" alias:"stdout"` + Path string `json:"path" xml:"path" alias:"path"` + Handlers []any `json:"-" xml:"-" alias:"handlers"` + Chan chan int + } + type FilterType struct { + String string `alias:"string"` + Int int `alias:"int"` + Uint uint `alias:"uint"` + Float float64 `alias:"float"` + Bool bool `alias:"bool"` + Any any `alias:"any"` + } + + app := eudore.NewApp() + app.SetValue(eudore.ContextKeyRender, eudore.HandlerDataFunc(func(ctx eudore.Context, i any) error { + ctx.Debugf("%#v", i) + return nil + })) + fc := eudore.NewFuncCreator() + app.SetValue(eudore.ContextKeyFuncCreator, fc) + app.SetValue(eudore.ContextKeyFilter, eudore.NewFilterRules(app)) + app.SetValue(eudore.ContextKeyContextPool, eudore.NewContextBasePool(app)) + app.AddMiddleware(middleware.NewLoggerFunc(app)) + + app.GetFunc("/data", func(ctx eudore.Context) { + data := []any{ + []string{"path=zero"}, + &LoggerConfig{Stdout: true}, + &eudore.FilterData{ + Name: "*", + Checks: []string{"path=zero"}, + Modifys: []string{"stdout=value:true"}, + }, + &LoggerConfig{}, + &eudore.FilterData{Name: "LoggerConfig", Package: "eudore", Checks: []string{"path=zero"}}, + &LoggerConfig{}, + &eudore.FilterData{Name: "Logger*", Checks: []string{"path=zero"}}, + &LoggerConfig{}, + &eudore.FilterData{Name: "App", Checks: []string{"path=zero"}}, + &LoggerConfig{}, + &eudore.FilterData{Name: "App*", Checks: []string{"path=zero"}}, + &LoggerConfig{}, + &eudore.FilterData{Name: "Logger*1", Checks: []string{"path=zero"}}, + &LoggerConfig{}, + []eudore.FilterData{{Checks: []string{"path=zero"}}}, + &LoggerConfig{}, + &eudore.FilterData{Checks: []string{"path=zero"}}, + []*LoggerConfig{{}, {Path: "app.log"}}, + &eudore.FilterData{Checks: []string{"path=zero"}}, + []any{LoggerConfig{}, &LoggerConfig{Path: "app.log"}}, + []string{"link=zero"}, + &LoggerConfig{}, + []string{"Chan=k"}, + &LoggerConfig{}, + []string{"handlers=k"}, + &LoggerConfig{}, + &eudore.FilterData{ + Checks: []string{"string=zero", "int=zero", "uint=zero", "float=zero", "bool=zero", "any=zero"}, + Modifys: []string{"string=now:20060102", "int=value:4", "uint=value:4", "float=value:4", "bool=value:true", "any=now"}, + }, + &FilterType{}, + &eudore.FilterData{ + Modifys: []string{"any=default"}, + }, + &FilterType{}, + } + + for i := 0; i < len(data); i += 2 { + ctx.SetValue(eudore.ContextKeyFilterRules, data[i]) + ctx.Render(data[i+1]) + } + }) + app.NewRequest(nil, "GET", "/data") + + meta, ok := fc.(interface{ Metadata() any }).Metadata().(eudore.MetadataFuncCreator) + if ok { + for _, err := range meta.Errors { + app.Debug("err:", err) + } + } app.CancelFunc() app.Run() } diff --git a/_example/logger_test.go b/_example/logger_test.go index 6344899..16b3c68 100644 --- a/_example/logger_test.go +++ b/_example/logger_test.go @@ -15,122 +15,269 @@ import ( "github.com/eudore/eudore" ) -func TestLogger(t *testing.T) { - log := eudore.NewApp() - log.SetValue(eudore.ContextKeyLogger, eudore.NewLoggerInit()) - - log.SetLevel(eudore.LoggerFatal) - log.Debug("0") - log.Debugf("0") - log.Info("1") - log.Infof("1") - log.Warning("2") - log.Warningf("2") - log.Error("3") - log.Errorf("3") - log.Fatal("4") - log.Fatalf("4") - - log.SetLevel(eudore.LoggerDebug) - log.Debug("0") - log.Debugf("0") - log.Info("1") - log.Infof("1") - log.Warning("2") - log.Warningf("2") - log.Error("3") - log.Errorf("3") - log.Fatal("4") - log.Fatalf("4") - - log.WithField("key", "field").Debug("0") - log.WithField("key", "field").Debugf("0") - log.WithField("key", "field").Info("1") - log.WithField("key", "field").Infof("1") - log.WithField("key", "field").Warning("2") - log.WithField("key", "field").Warningf("2") - log.WithField("key", "field").Error("3") - log.WithField("key", "field").Errorf("3") - log.WithField("key", "field").Fatal("4") - log.WithField("key", "field").Fatalf("4") - - log.WithFields([]string{"key"}, []interface{}{"Fields"}).Debug("0") - log.WithFields([]string{"key"}, []interface{}{"Fields"}).Debugf("0") - log.WithFields([]string{"key"}, []interface{}{"Fields"}).Info("1") - log.WithFields([]string{"key"}, []interface{}{"Fields"}).Infof("1") - log.WithFields([]string{"key"}, []interface{}{"Fields"}).Warning("2") - log.WithFields([]string{"key"}, []interface{}{"Fields"}).Warningf("2") - log.WithFields([]string{"key"}, []interface{}{"Fields"}).Error("3") - log.WithFields([]string{"key"}, []interface{}{"Fields"}).Errorf("3") - log.WithFields([]string{"key"}, []interface{}{"Fields"}).Fatal("4") - log.WithFields([]string{"key"}, []interface{}{"Fields"}).Fatalf("4") - - log.Sync() - // 设置logger - log.SetValue(eudore.ContextKeyLogger, eudore.NewLoggerStd(nil)) - - log.SetLevel(eudore.LoggerDebug) - log.Debug("0") - log.Debugf("0") - log.Info("1") - log.Infof("1") - log.Warning("2") - log.Warningf("2") - log.Error("3") - log.Errorf("3") - log.Fatal("4") - log.Fatalf("4") - - log.WithField("key", "field").Debug("0") - log.WithField("key", "field").Debugf("0") - log.WithField("key", "field").Info("1") - log.WithField("key", "field").Infof("1") - log.WithField("key", "field").Warning("2") - log.WithField("key", "field").Warningf("2") - log.WithField("key", "field").Error("3") - log.WithField("key", "field").Errorf("3") - log.WithField("key", "field").Fatal("4") - log.WithField("key", "field").Fatalf("4") - - log.WithFields([]string{"key"}, []interface{}{"Fields"}).Debug("0") - log.WithFields([]string{"key", "k2"}, []interface{}{"Fields"}).Debug("0") - log.WithFields([]string{"key"}, []interface{}{"Fields"}).Debugf("0") - log.WithFields([]string{"key"}, []interface{}{"Fields"}).Info("1") - log.WithFields([]string{"key"}, []interface{}{"Fields"}).Infof("1") - log.WithFields([]string{"key"}, []interface{}{"Fields"}).Warning("2") - log.WithFields([]string{"key"}, []interface{}{"Fields"}).Warningf("2") - log.WithFields([]string{"key"}, []interface{}{"Fields"}).Error("3") - log.WithFields([]string{"key"}, []interface{}{"Fields"}).Errorf("3") - log.WithFields([]string{"key"}, []interface{}{"Fields"}).Fatal("4") - log.WithFields([]string{"key"}, []interface{}{"Fields"}).Fatalf("4") - - log.WithField("depth", "stack").Info("1") - - log.Sync() - eudore.DefaultLoggerNull.Sync() - - log.CancelFunc() - log.Run() +func TestLoggerStd(t *testing.T) { + app := eudore.NewApp() + + app.GetLevel() + app.SetLevel(eudore.LoggerFatal) + app.Debug("0") + app.Debugf("0") + app.Info("1") + app.Infof("1") + app.Warning("2") + app.Warningf("2") + app.Error("3") + app.Errorf("3") + app.Fatal("4") + app.Fatalf("4") + + app.SetLevel(eudore.LoggerDebug) + app.Debug("0") + app.Debugf("0") + app.Info("1") + app.Infof("1") + app.Warning("2") + app.Warningf("2") + app.Error("3") + app.Errorf("3") + app.Fatal("4") + app.Fatalf("4") + + app.WithField("depth", "enable").Info("1") + app.WithField("depth", "stack").Info("1") + app.WithField("depth", "disable").Info("1") + app.WithField("depth", -2).WithField("depth", "enable").Info("1") + app.WithField("depth", true).Info("1") + app.WithField("context", app).WithField("context", app.Context).Info("1") + app.WithField("caller", "logger").WithField("logger", true).Info("1") + app.WithFields([]string{"key"}, []any{}).Info("1") + + app.Logger.(interface{ Metadata() any }).Metadata() + eudore.NewLoggerWithContext(app) + eudore.NewLoggerWithContext(context.Background()) + + app.CancelFunc() + app.Run() } -func TestLoggerInit(t *testing.T) { - log := eudore.NewApp() - log.SetValue(eudore.ContextKeyLogger, eudore.NewLoggerInit()) - log.Info("loggerInit to end") - log.CancelFunc() - log.Run() +func TestLoggerInit1(t *testing.T) { + app := eudore.NewApp() + app.SetValue(eudore.ContextKeyLogger, eudore.NewLoggerInit()) + + app.Debug("0") + app.Infof("1") + app.Warning("2") + app.Error("3") + app.Fatal("4") + app.Logger.(interface{ Metadata() any }).Metadata() + + app.SetValue(eudore.ContextKeyLogger, eudore.NewLogger(nil)) + app.CancelFunc() + app.Run() } -func TestLoggerOption(t *testing.T) { - log := eudore.NewLoggerStd(nil) +func TestLoggerInit2(t *testing.T) { + app := eudore.NewApp() + app.SetValue(eudore.ContextKeyLogger, eudore.NewLoggerInit()) + app.Info("loggerInit to end") + app.CancelFunc() + app.Run() +} - // logger depth - log = log.WithField("depth", "enable").WithField("context", context.Background()).WithField("caller", "log depth").WithField("logger", true) - log.Info("file line") - log.WithField("depth", "disable").Info("file line") - log.WithField("depth", []string{"disable"}).Info("file line") +func TestLoggerFormatterText(t *testing.T) { + app := eudore.NewApp() + app.SetValue(eudore.ContextKeyLogger, eudore.NewLogger(&eudore.LoggerConfig{ + Caller: true, + Stdout: true, + Formatter: "text", + })) + loggerWriteData(app) - log.WithField("context", context.TODO()).Info("logger context") + app.CancelFunc() + app.Run() +} + +func TestLoggerFormatterJSON(t *testing.T) { + app := eudore.NewApp() + loggerWriteData(app) + app.WithField("utf8", "世界\\ \n \r \t \002 \321 \u2028").Debug() + app.WithField("field", new(marsha1)).Debug("marsha1") + app.WithField("field", new(marsha2)).Debug("marsha2") + app.WithField("field", new(marsha3)).Debug("marsha3") + app.WithField("field", new(marsha4)).Debug("marsha4") + app.WithField("field", new(marsha5)).Debug("marsha5") +} + +type marsha1 struct{} +type marsha2 struct{} +type marsha3 struct{} +type marsha4 struct{} +type marsha5 bool + +func (marsha1) MarshalJSON() ([]byte, error) { + return []byte("\"marsha1\""), nil +} +func (marsha2) MarshalJSON() ([]byte, error) { + return []byte("\"marsha2\""), errors.New("test marshal error") +} +func (marsha3) MarshalText() ([]byte, error) { + return []byte("\\ \n \r \t \002 \321 世界"), nil +} +func (marsha4) MarshalText() ([]byte, error) { + return []byte("\\ \n \r \t \002 \321 世界"), errors.New("\\ \n \r \t \002 \321 世界") +} +func (marsha5) Method() {} + +type StructCycle struct { + Name string `json:"name,omitempty"` + Err error + *StructCycle +} + +type StructAnon struct { + Duration *eudore.TimeDuration + Now *time.Time + *eudore.LoggerConfig + eudore.ServerConfig +} + +func loggerWriteData(log eudore.Logger) { + type M map[string]any + var eptr *time.Time + var eany any + var cslice []any + var cmap = make(map[string]any) + var cycle = &StructCycle{} + var echan chan int + var dura eudore.TimeDuration + cslice = append(cslice, "slice", 0, cslice) + cmap["data"] = "map data" + cmap["this"] = cmap + cycle.StructCycle = cycle + + log = log.WithField("depth", "disable").WithField("logger", true) + log.WithField("json", eudore.LoggerDebug).Debug() + log.WithField("fmt.Stringer", eudore.HandlerFunc(eudore.HandlerEmpty)).Debug() + log.WithField("fmt.Stringer", &dura).Debug() + log.WithField("error", fmt.Errorf("logger wirte error")).Debug() + log.WithField("bool", true).Debug() + log.WithField("int", 1).Debug() + log.WithField("uint", uint(2)).Debug() + log.WithField("float", 3.3).Debug() + log.WithField("complex", complex(4.1, 4.2)).Debug() + log.WithField("map", map[string]int{"a": 1, "b": 2}).Debug() + log.WithField("map alias", M{"a": 1, "b": 2}).Debug() + log.WithField("map empty", map[string]int{}).Debug() + log.WithField("map cycle", cmap).Debug() + log.WithField("struct", struct{ Name string }{"name"}).Debug() + log.WithField("struct empty", struct{}{}).Debug() + log.WithField("struct cycle", cycle).Debug() + log.WithField("struct anonymous", &StructAnon{}).Debug() + log.WithField("ptr", &struct{ Name string }{"name"}).Debug() + log.WithField("ptr empty", eptr).Debug() + log.WithField("slice empty", cslice[0:0]).Debug() + log.WithField("slice cycle", cslice).Debug() + log.WithField("array", []int{1, 2, 3}).Debug() + log.WithField("func", eudore.NewApp).Debug() + log.WithField("bytes", []byte("bytes")).Debug() + log.WithField("any empty", eany).Debug() + log.WithField("any empty", []any{eany}).Debug() + log.WithField("chan empty", echan).Debug() + log.WithField("depth", "disable").Info("depth") + log.WithField("depth", "enable").Info("depth") + log.WithField("depth", "stack").Info("depth") +} + +type logConfig struct { + Level1 eudore.LoggerLevel `alias:"level" json:"level1"` + Level2 eudore.LoggerLevel `alias:"level2" json:"level2"` + Level3 eudore.LoggerLevel `alias:"level3" json:"level3"` +} + +func TestLoggerLevel(t *testing.T) { + conf := &logConfig{} + jsonBlob := []byte(`{"level1":"1","level2":"info","level3":"3"}`) + err := json.Unmarshal(jsonBlob, conf) + t.Logf("%v\t%#v\n", err, conf) + jsonBlob = []byte(`{"level3": "33"}`) + err = json.Unmarshal(jsonBlob, conf) + t.Logf("%v\t%#v\n", err, conf) + json.Marshal(conf) + t.Log(conf.Level1) +} + +func TestLoggerHookFatal(t *testing.T) { + defer func() { + t.Logf("LoggerHookFatal recover %v", recover()) + }() + + app := eudore.NewApp() + app.SetValue(eudore.ContextKeyLogger, eudore.NewLogger(&eudore.LoggerConfig{ + Stdout: true, + HookFatal: true, + })) + + app.Fatal("stop app") + app.Run() + + log := eudore.NewLogger(&eudore.LoggerConfig{ + Stdout: true, + HookFatal: true, + }) + log.Fatal("stop logger") +} + +func TestLoggerHookFilter(t *testing.T) { + app := eudore.NewApp() + fc := eudore.NewFuncCreator() + app.SetValue(eudore.ContextKeyFuncCreator, fc) + app.SetValue(eudore.ContextKeyLogger, eudore.NewLogger(&eudore.LoggerConfig{ + Stdout: true, + HookFilter: [][]string{ + {"path string prefix=/static/", "status int equal:200"}, + {"mail setstring hidemail", "password setstring hide"}, + {"ti setint default", "tu setuint default", "tf setfloat default", "tb setbool default", "ta setany default"}, + {"1 1 1"}, + {"strs setstring default", "t setany add=240h"}, + }, + })) + + app.WithFields([]string{"path", "status"}, []any{"/index", 200}).Info() + app.WithFields([]string{"path", "status"}, []any{"/static/index", 200}).Info() + app.WithFields([]string{"path"}, []any{true}).Info() + app.WithField("mail", "postmaster@eudore.cn").WithField("password", "123456").Info() + app.WithFields([]string{"ti", "tu", "tf", "tb", "ta"}, []any{1, uint(1), 1.0, true, time.Now()}).Info() + app.WithFields([]string{"path"}, []any{nil}).Info() + app.WithFields([]string{"strs", "t"}, []any{[]string{"", "2"}, time.Now()}).Debug() + + meta, ok := fc.(interface{ Metadata() any }).Metadata().(eudore.MetadataFuncCreator) + if ok { + for _, err := range meta.Errors { + app.Debug("err:", err) + } + } + + time.Sleep(1000 * time.Millisecond) + app.CancelFunc() + app.Run() +} + +func TestLoggerWriterStdout(t *testing.T) { + app := eudore.NewApp() + app.SetValue(eudore.ContextKeyLogger, eudore.NewLogger(&eudore.LoggerConfig{ + Stdout: true, + StdColor: true, + })) + app.Info("color") + + eudore.NewLogger(&eudore.LoggerConfig{ + Stdout: true, + StdColor: true, + TimeFormat: time.RFC3339Nano + " " + time.RFC3339Nano, + }).Debug("disable color") + + app.CancelFunc() + app.Run() } func TestLoggerWriterFile(t *testing.T) { @@ -140,251 +287,101 @@ func TestLoggerWriterFile(t *testing.T) { // file logfile := "tmp-loggerStd.log" - log := eudore.NewLoggerStd(&eudore.LoggerStdConfig{ + log := eudore.NewLogger(&eudore.LoggerConfig{ Path: logfile, }) defer os.Remove(logfile) - - log.Info("hello") - log.Sync() - os.Remove(logfile) - - // file and std - log = eudore.NewLoggerStd(&eudore.LoggerStdConfig{ - Std: true, - Path: logfile, - }) log.Info("hello") - log.Sync() - os.Remove(logfile) // create error func() { defer func() { t.Logf("NewLoggerWriterFile recover %v", recover()) }() - log = eudore.NewLoggerStd(&eudore.LoggerStdConfig{ - Path: "out-yyyy-MM-dd-HH-index---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------.log", + log = eudore.NewLogger(&eudore.LoggerConfig{ + Path: "out----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------/1.log", }) }() func() { defer func() { t.Logf("NewLoggerWriterFile recover %v", recover()) }() - log = eudore.NewLoggerStd(&eudore.LoggerStdConfig{ + log = eudore.NewLogger(&eudore.LoggerConfig{ Path: "out----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------.log", }) }() - log.Info("hello") - log.Sync() } func TestNewLoggerWriterRotate(t *testing.T) { defer os.RemoveAll("logger") - defer os.RemoveAll("logger2") { // 占用一个索引文件 rotate跳过 os.Mkdir("logger", 0644) - file, err := os.OpenFile("logger/logger-out-2.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + file, err := os.OpenFile("logger/app-2.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) if err == nil { str := []byte("eudore logger Writer test.") for i := 0; i < 1024; i++ { file.Write(str) } - file.Sync() file.Close() } } // date - log := eudore.NewLoggerStd(&eudore.LoggerStdConfig{ - Std: true, - Path: "logger2/logger-yyyy-MM-dd-HH-index.log", + log := eudore.NewLogger(&eudore.LoggerConfig{ + Stdout: true, + Path: "logger/app-yyyy-mm-dd-hh.log", }) log.Info("hello") - // no link - log = eudore.NewLoggerStd(&eudore.LoggerStdConfig{ - Std: true, - Path: "logger/logger-out-index.log", + // size + log = eudore.NewLogger(&eudore.LoggerConfig{ + Stdout: true, + Path: "logger/app-size.log", MaxSize: 16 << 10, }) log.Info("hello") - // rotate file - log = eudore.NewLoggerStd(&eudore.LoggerStdConfig{ - Path: "logger/logger-out-index.log", - MaxSize: 16 << 10, - Link: "logger/app.log", + // link + log = eudore.NewLogger(&eudore.LoggerConfig{ + Path: "logger/app.log", + Link: "logger/app.log", + MaxSize: 16 << 10, + MaxCount: 3, }) log.Info("hello") - log = log.WithFields([]string{"name", "type"}, []interface{}{"eudore", "logger"}).WithField("logger", true) + log = log.WithFields([]string{"name", "type"}, []any{"eudore", "logger"}).WithField("logger", true) for i := 0; i < 1000; i++ { log.Info("test rotate") } - log.Sync() -} - -type ( - marsha1 struct{} - marsha2 struct{} - marsha3 struct{} - marsha4 struct{} - marsha5 struct{ Num []int } -) - -func (marsha1) MarshalJSON() ([]byte, error) { - return []byte("marsha1"), nil -} -func (marsha2) MarshalJSON() ([]byte, error) { - return []byte("marsha2"), errors.New("test marshal error") -} -func (marsha3) MarshalText() ([]byte, error) { - return []byte("\\ \n \r \t \002 \321 世界"), nil -} -func (marsha4) MarshalText() ([]byte, error) { - return []byte("\\ \n \r \t \002 \321 世界"), errors.New("\\ \n \r \t \002 \321 世界") -} - -func TestLoggerStdJSON(t *testing.T) { - var ptr *time.Time - var slice = []int{1} - log := eudore.NewLoggerStd(nil) - log.Debug("debug") - log.Info("info") - log.WithField("json", eudore.LoggerDebug).Debug() - log.WithField("stringer", eudore.HandlerFunc(eudore.HandlerEmpty)).Debug("2") - log.WithField("error", fmt.Errorf("logger wirte error")).Debug() - log.WithField("bool", true).Debug("2") - log.WithField("int", 1).Debug("2") - log.WithField("uint", uint(2)).Debug("2") - log.WithField("float", 3.3).Debug("2") - log.WithField("complex", complex(4.1, 4.2)).Debug("2") - log.WithField("array", []int{1, 2, 3}).Debug("2") - log.WithField("map", map[string]int{"a": 1, "b": 2}).Debug("2") - log.WithField("struct", struct{ Name string }{"name"}).Debug("2") - log.WithField("struct", struct{}{}).Debug("2") - log.WithField("ptr", &struct{ Name string }{"name"}).Debug("2") - log.WithField("ptr", ptr).Debug("2") - log.WithField("slice", slice[0:0]).Debug("2") - log.WithField("emptry face", []interface{}{ptr}).Debug("2") - log.WithField("func", TestLoggerStdJSON).Debug("2") - log.WithField("bytes", []byte("bytes")).Debug("2") - var i interface{} - log.WithField("nil", i).Debug("2") - - log.WithField("utf8 string", "\\ \n \r \t \002 \321 世界").Debug("2") - log.WithField("utf8 bytes", []byte("\\ \n \r \t \002 \321 世界")).Debug("2") - - log.WithField("nil", new(marsha1)).Debug("marsha1") - log.WithField("nil", new(marsha2)).Debug("marsha2") - log.WithField("nil", new(marsha3)).Debug("marsha3") - log.WithField("nil", new(marsha4)).Debug("marsha4") - log.WithField("nil", new(marsha5)).Debug("marsha5") - - log.Sync() -} - -type logConfig struct { - Level1 eudore.LoggerLevel `alias:"level" json:"level1"` - Level2 eudore.LoggerLevel `alias:"level2" json:"level2"` - Level3 eudore.LoggerLevel `alias:"level3" json:"level3"` -} - -func TestLoggerLevel(t *testing.T) { - conf := &logConfig{} - var jsonBlob = []byte(`{"level1":"1","level2":"info","level3":"3"}`) - err := json.Unmarshal(jsonBlob, conf) - t.Logf("%v\t%#v\n", err, conf) - jsonBlob = []byte(`{"level3": "33"}`) - err = json.Unmarshal(jsonBlob, conf) - t.Logf("%v\t%#v\n", err, conf) -} - -type loggerStdData016 struct { - eudore.LoggerStdData - meta *loggerStdData016Meta -} - -type loggerStdData016Meta struct { - debug int64 - info int64 - warning int64 - error int64 - fatal int64 -} - -func (log loggerStdData016) GetLogger() *eudore.LoggerStd { - entry := log.LoggerStdData.GetLogger() - _, ok := entry.LoggerStdData.(loggerStdData016) - if !ok { - entry.LoggerStdData = loggerStdData016{entry.LoggerStdData, log.meta} - } - return entry -} - -func (log loggerStdData016) PutLogger(entry *eudore.LoggerStd) { - switch entry.Level { - case eudore.LoggerDebug: - log.meta.debug++ - case eudore.LoggerInfo: - log.meta.info++ - case eudore.LoggerWarning: - log.meta.warning++ - case eudore.LoggerError: - log.meta.error++ - case eudore.LoggerFatal: - log.meta.fatal++ - } - log.LoggerStdData.PutLogger(entry) -} - -func (log loggerStdData016) Metadata() interface{} { - return map[string]interface{}{ - "name": "loggerStdData016", - "debug": log.meta.debug, - "info": log.meta.info, - "warning": log.meta.warning, - "ererror": log.meta.error, - "fatal": log.meta.fatal, - } } -func TestMetadata(t *testing.T) { - app := eudore.NewApp() - meta, ok := app.Logger.(interface{ Metadata() interface{} }) - if ok { - app.Infof("%#v", meta.Metadata()) - } +func TestLoggerMonkeyErr(t *testing.T) { + defer func() { + t.Logf("MonkeyErr recover %v", recover()) + }() - app.SetValue(eudore.ContextKeyLogger, eudore.NewLoggerStd(&loggerStdData016{eudore.NewLoggerStdDataJSON(nil), &loggerStdData016Meta{}})) - meta, ok = app.Logger.(interface{ Metadata() interface{} }) - if ok { - app.Infof("%#v", meta.Metadata()) - } + patchCallers := monkey.Patch(runtime.Callers, func(int, []uintptr) int { return 0 }) + defer patchCallers.Unpatch() + patchOpen := monkey.Patch(os.OpenFile, func(name string, flag int, perm os.FileMode) (*os.File, error) { + return nil, fmt.Errorf("monkey no open") + }) + defer patchOpen.Unpatch() - app.CancelFunc() - app.Run() + eudore.GetCallerStacks(0) + eudore.NewLogger(&eudore.LoggerConfig{ + Path: "app-yyyy-mm-dd.log", + }) } -func TestLoggerMonkey(t *testing.T) { - patch1 := monkey.Patch(runtime.Caller, func(int) (uintptr, string, int, bool) { return 0, "", 0, false }) - patch2 := monkey.Patch(runtime.Callers, func(int, []uintptr) int { return 0 }) - defer patch1.Unpatch() - defer patch2.Unpatch() - - log := eudore.NewLoggerStd(nil) - log.WithField("depth", "enable").Error(eudore.GetPanicStack(0)) - +func TestLoggerMonkeyTime(t *testing.T) { defer os.RemoveAll("logger") - log = eudore.NewLoggerStd(&eudore.LoggerStdConfig{ - Path: "logger/logger-yyyy-MM-dd-HH-index.log", - Link: "logger/logger.log", - MaxSize: 1 << 10, // 1k - Std: false, - Level: eudore.LoggerDebug, - TimeFormat: "Mon Jan 2 15:04:05 -0700 MST 2006", + log := eudore.NewLogger(&eudore.LoggerConfig{ + Path: "logger/app-yyyy-mm-dd-hh.log", + Link: "logger/app.log", + MaxSize: 1 << 10, // 1k + MaxCount: 10, }) // This is as unsafe as it sounds and I don't recommend anyone do it outside of a testing environment. @@ -408,18 +405,17 @@ func TestLoggerMonkey(t *testing.T) { } func BenchmarkLoggerStd(b *testing.B) { - data := map[string]interface{}{ + data := map[string]any{ "a": 1, "b": 2, } - log := eudore.NewLoggerStd(&eudore.LoggerStdConfig{ + log := eudore.NewLogger(&eudore.LoggerConfig{ Path: "t2.log", }) b.ReportAllocs() for i := 0; i < b.N; i++ { - log.WithFields([]string{"animal", "number", "size"}, []interface{}{"walrus", 1, 10}).Info("A walrus appears") + log.WithFields([]string{"animal", "number", "size"}, []any{"walrus", 1, 10}).Info("A walrus appears") log.WithField("a", 1).WithField("b", true).Info(data) } - log.Sync() os.Remove("t2.log") } diff --git a/_example/middleware2_test.go b/_example/middleware2_test.go new file mode 100644 index 0000000..9802731 --- /dev/null +++ b/_example/middleware2_test.go @@ -0,0 +1,701 @@ +package eudore_test + +import ( + "bytes" + "context" + "fmt" + "net" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/eudore/eudore" + "github.com/eudore/eudore/middleware" +) + +func TestMiddlewareBlack(*testing.T) { + middleware.NewBlackFunc(map[string]bool{ + "192.168.0.0/16": true, + "0.0.0.0/0": false, + }, nil) + + app := eudore.NewApp() + app.AddMiddleware(middleware.NewBlackFunc(map[string]bool{ + "192.168.100.0/24": true, + "192.168.75.0/30": true, + "192.168.1.100/30": true, + "127.0.0.1/32": true, + "10.168.0.0/16": true, + "0.0.0.0/0": false, + }, app.Group("/eudore/debug"))) + app.AnyFunc("/*", eudore.HandlerEmpty) + + app.NewRequest(nil, "GET", "/eudore/debug/black/ui") + app.NewRequest(nil, "GET", "/eudore/debug/black/ui") + app.NewRequest(nil, "PUT", "/eudore/debug/black/black/10.127.87.0?mask=24") + app.NewRequest(nil, "PUT", "/eudore/debug/black/white/10.127.87.0?mask=24") + app.NewRequest(nil, "GET", "/eudore/debug/black/data") + app.NewRequest(nil, "DELETE", "/eudore/debug/black/black/10.127.87.0?mask=24") + app.NewRequest(nil, "DELETE", "/eudore/debug/black/white/10.127.87.0?mask=24") + app.NewRequest(nil, "DELETE", "/eudore/debug/black/black/10.127.87.0?mask=24") + app.NewRequest(nil, "DELETE", "/eudore/debug/black/white/10.127.87.0?mask=24") + + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"127.0.0.1:29398"}}, eudore.NewClientCheckStatus(403)) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"127.0.0.1:29398"}}, eudore.NewClientCheckStatus(200)) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"192.168.75.1:8298"}}, eudore.NewClientCheckStatus(200)) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"192.168.100.3/28"}}, eudore.NewClientCheckStatus(200)) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"192.168.100.0"}}, eudore.NewClientCheckStatus(200)) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"192.168.100.1"}}, eudore.NewClientCheckStatus(200)) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"192.168.100.77"}}, eudore.NewClientCheckStatus(200)) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"192.168.100.148"}}, eudore.NewClientCheckStatus(200)) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"192.168.100.222"}}, eudore.NewClientCheckStatus(200)) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"192.168.75.4"}}, eudore.NewClientCheckStatus(403)) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"192.168.75.5"}}, eudore.NewClientCheckStatus(403)) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"192.168.75.6"}}, eudore.NewClientCheckStatus(403)) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"192.168.1.99"}}, eudore.NewClientCheckStatus(403)) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"192.168.1.100"}}, eudore.NewClientCheckStatus(200)) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"192.168.1.101"}}, eudore.NewClientCheckStatus(200)) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"192.168.1.102"}}, eudore.NewClientCheckStatus(200)) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"192.168.1.103"}}, eudore.NewClientCheckStatus(200)) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"192.168.1.104"}}, eudore.NewClientCheckStatus(403)) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"192.168.1.105"}}, eudore.NewClientCheckStatus(403)) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"127.0.0.1"}}) + app.NewRequest(nil, "GET", "/eudore", eudore.NewClientCheckStatus(403)) + + app.NewRequest(nil, "DELETE", "/eudore/debug/black/white/0.0.0.0?mask=0") + app.NewRequest(nil, "PUT", "/eudore/debug/black/white/192.168.75.4?mask=30") + app.NewRequest(nil, "DELETE", "/eudore/debug/black/white/192.168.75.1") + app.NewRequest(nil, "DELETE", "/eudore/debug/black/white/192.168.75.5") + app.NewRequest(nil, "DELETE", "/eudore/debug/black/white/192.168.75.7") + app.NewRequest(nil, "PUT", "/eudore/debug/black/white/10.16.0.0?mask=16") + app.NewRequest(nil, "DELETE", "/eudore/debug/black/white/192.168.75.4?mask=30") + + app.CancelFunc() + app.Run() +} + +func TestMiddlewareBreaker(*testing.T) { + middleware.NewBreakerFunc(nil) + + app := eudore.NewApp() + // 创建熔断器并注入管理路由 + breaker := middleware.NewBreaker() + breaker.MaxConsecutiveSuccesses = 3 + breaker.MaxConsecutiveFailures = 3 + breaker.OpenWait = 0 + app.AddMiddleware(middleware.NewLoggerFunc(app, "route")) + app.AddMiddleware(breaker.NewBreakerFunc(app.Group("/eudore/debug"))) + app.AnyFunc("/*", func(ctx eudore.Context) { + if len(ctx.Querys()) > 0 { + ctx.Fatal("test err") + return + } + ctx.WriteString("route: " + ctx.GetParam("route")) + }) + + // 错误请求 + for i := 0; i < 10; i++ { + app.NewRequest(nil, "GET", "/1?a=1") + } + for i := 0; i < 5; i++ { + time.Sleep(time.Millisecond * 500) + app.NewRequest(nil, "GET", "/1?a=1") + } + // 除非熔断后访问 + for i := 0; i < 5; i++ { + time.Sleep(time.Millisecond * 500) + app.NewRequest(nil, "GET", "/1") + } + + app.NewRequest(nil, "GET", "/eudore/debug/breaker/ui") + app.NewRequest(nil, "GET", "/eudore/debug/breaker/ui") + app.NewRequest(nil, "GET", "/eudore/debug/breaker/data", http.Header{eudore.HeaderAccept: {eudore.MimeApplicationJSON}}) + app.NewRequest(nil, "GET", "/eudore/debug/breaker/1") + app.NewRequest(nil, "GET", "/eudore/debug/breaker/100") + app.NewRequest(nil, "PUT", "/eudore/debug/breaker/1/state/0") + app.NewRequest(nil, "PUT", "/eudore/debug/breaker/1/state/3") + app.NewRequest(nil, "PUT", "/eudore/debug/breaker/3/state/3") + + time.Sleep(time.Microsecond * 100) + app.CancelFunc() + app.Run() +} + +func TestMiddlewareCacheData(*testing.T) { + app := eudore.NewApp() + app.AddMiddleware("global", middleware.NewLoggerFunc(app, "route")) + app.AddMiddleware(middleware.NewCacheFunc(time.Second/100, app.Context, func(ctx eudore.Context) string { + // 自定义缓存key函数,默认实现方法 + if ctx.Method() != eudore.MethodGet || ctx.GetHeader(eudore.HeaderUpgrade) != "" { + return "" + } + return ctx.Request().URL.RequestURI() + })) + app.AnyFunc("/sf", func(ctx eudore.Context) { + ctx.Redirect(301, "/") + ctx.Debug(ctx.Response().Status(), ctx.Response().Size()) + }) + app.AnyFunc("/*", func(ctx eudore.Context) { + time.Sleep(time.Second / 200) + ctx.WriteString("hello eudore") + }) + + app.NewRequest(nil, "GET", "/sf") + wg := sync.WaitGroup{} + wg.Add(5) + for n := 0; n < 5; n++ { + go func() { + for i := 0; i < 5; i++ { + var o any + app.NewRequest(nil, "GET", "/?c="+fmt.Sprint(i), func(resp *http.Response) error { + if resp != nil { + o = eudore.NewClientOptionHeader(eudore.HeaderIfModifiedSince, + resp.Header.Get(eudore.HeaderLastModified), + ) + } + return nil + }) + app.NewRequest(nil, "GET", "/?c="+fmt.Sprint(i), o) + time.Sleep(time.Millisecond * 20) + app.NewRequest(nil, "GET", "/?c="+fmt.Sprint(i), o) + } + wg.Done() + }() + } + wg.Wait() + + app.NewRequest(nil, "GET", "/sf") + app.NewRequest(nil, "POST", "/sf") + time.Sleep(time.Millisecond * 220) + app.NewRequest(nil, "GET", "/s") + + app.CancelFunc() + app.Run() +} + +func TestMiddlewareCacheStore(*testing.T) { + app := eudore.NewApp() + app.AddMiddleware("global", middleware.NewLoggerFunc(app, "route")) + app.AddMiddleware(middleware.NewCacheFunc(time.Second/100, app.Context, new(cacheMap))) + app.AnyFunc("/sf", func(ctx eudore.Context) { + ctx.Redirect(301, "/") + ctx.Debug(ctx.Response().Status(), ctx.Response().Size()) + }) + app.AnyFunc("/*", func(ctx eudore.Context) { + time.Sleep(time.Second / 200) + ctx.WriteString("hello eudore") + }) + + app.NewRequest(nil, "GET", "/sf") + wg := sync.WaitGroup{} + wg.Add(5) + for n := 0; n < 5; n++ { + go func() { + for i := 0; i < 3; i++ { + app.NewRequest(nil, "GET", "/?c="+fmt.Sprint(i)) + app.NewRequest(nil, "GET", "/?c="+fmt.Sprint(i)) + time.Sleep(time.Millisecond * 20) + app.NewRequest(nil, "GET", "/?c="+fmt.Sprint(i)) + } + wg.Done() + }() + } + wg.Wait() + + app.NewRequest(nil, "GET", "/sf") + app.NewRequest(nil, "POST", "/sf") + app.NewRequest(nil, "GET", "/s") + + app.CancelFunc() + app.Run() +} + +type cacheMap struct { + sync.Map +} + +func (m *cacheMap) Load(key string) *middleware.CacheData { + data, ok := m.Map.Load(key) + if !ok { + return nil + } + item := data.(*middleware.CacheData) + if time.Now().After(item.Expired) { + m.Map.Delete(key) + return nil + } + fmt.Println("cache", key) + return item +} + +func (m *cacheMap) Store(key string, val *middleware.CacheData) { + fmt.Println("new", key) + m.Map.Store(key, val) +} + +func TestMiddlewareRateRequest(*testing.T) { + app := eudore.NewApp() + app.AnyFunc("/*", middleware.NewRateRequestFunc(1, 3, app.Context), eudore.HandlerEmpty) + + for i := 0; i < 8; i++ { + app.NewRequest(nil, "GET", "/") + } + + app.CancelFunc() + app.Run() +} + +func TestMiddlewareRateSpeed1(*testing.T) { + app := eudore.NewApp() + app.AddMiddleware(middleware.NewRateSpeedFunc(16*1024, 64*1024, app.Context)) + app.PostFunc("/post", func(ctx eudore.Context) { + ctx.Debug(string(ctx.Body())) + }) + app.AnyFunc("/srv", func(ctx eudore.Context) { + ctx.WriteString("rate speed 16kB") + }) + app.AnyFunc("/*", eudore.HandlerEmpty) + + app.NewRequest(nil, "POST", "/post", strings.NewReader("return body")) + app.NewRequest(nil, "PUT", "/srv") + + app.CancelFunc() + app.Run() +} + +func TestMiddlewareRateSpeed2(*testing.T) { + app := eudore.NewApp() + app.AnyFunc("/*", middleware.NewRateRequestFunc(3, 1, app.Context, time.Millisecond*100), eudore.HandlerEmpty) + + app.NewRequest(nil, "PUT", "/") + app.NewRequest(nil, "PUT", "/") + app.NewRequest(nil, "PUT", "/") + time.Sleep(time.Second / 10) + app.NewRequest(nil, "PUT", "/") + app.NewRequest(nil, "PUT", "/") + app.NewRequest(nil, "PUT", "/") + + app.CancelFunc() + app.Run() +} + +func TestMiddlewareRateSpeed3(*testing.T) { + app := eudore.NewApp() + app.AnyFunc("/*", middleware.NewRateRequestFunc(3, 1, app.Context, time.Microsecond*49), eudore.HandlerEmpty) + + app.NewRequest(nil, "PUT", "/") + app.NewRequest(nil, "PUT", "/") + app.NewRequest(nil, "PUT", "/") + app.NewRequest(nil, "PUT", "/") + app.NewRequest(nil, "PUT", "/") + time.Sleep(time.Second / 10) + + app.CancelFunc() + app.Run() +} + +func TestMiddlewareRateSpeedCannel1(*testing.T) { + app := eudore.NewApp() + app.AddMiddleware("/out", func(ctx eudore.Context) { + c1 := ctx.GetContext() + c2, cannel := context.WithTimeout(context.Background(), time.Millisecond*20) + go func() { + cannel() + }() + ctx.SetContext(c2) + ctx.Next() + ctx.SetContext(c1) + }) + app.AddMiddleware(middleware.NewRateRequestFunc(1, 3, app.Context, time.Millisecond*10, func(ctx eudore.Context) string { + return ctx.RealIP() + })) + app.AnyFunc("/out", eudore.HandlerEmpty) + app.AnyFunc("/*", eudore.HandlerEmpty) + + app.NewRequest(nil, "PUT", "/") + app.NewRequest(nil, "PUT", "/") + time.Sleep(50 * time.Millisecond) + app.NewRequest(nil, "PUT", "/out") + app.NewRequest(nil, "PUT", "/") + app.NewRequest(nil, "PUT", "/") + app.NewRequest(nil, "PUT", "/") + app.NewRequest(nil, "PUT", "/") + app.NewRequest(nil, "PUT", "/") + app.NewRequest(nil, "PUT", "/out") + app.NewRequest(nil, "PUT", "/") + + app.CancelFunc() + app.Run() +} + +func TestMiddlewareRateSpeedCannel2(*testing.T) { + app := eudore.NewApp() + app.AddMiddleware("/out", func(ctx eudore.Context) { + c, cannel := context.WithTimeout(ctx.GetContext(), time.Millisecond*2) + cannel() + ctx.SetContext(c) + }) + app.AddMiddleware(middleware.NewRateRequestFunc(1, 3, app.Context, time.Millisecond*10, func(ctx eudore.Context) string { + return ctx.RealIP() + })) + app.AnyFunc("/out", func(ctx eudore.Context) { + time.Sleep(time.Millisecond * 5) + }) + app.AnyFunc("/*", eudore.HandlerEmpty) + + app.NewRequest(nil, "PUT", "/") + app.NewRequest(nil, "PUT", "/") + app.NewRequest(nil, "PUT", "/") + app.NewRequest(nil, "PUT", "/") + app.NewRequest(nil, "PUT", "/out") + app.NewRequest(nil, "PUT", "/out") + app.NewRequest(nil, "PUT", "/out") + app.NewRequest(nil, "PUT", "/out") + + app.CancelFunc() + app.Run() +} + +func TestMiddlewareRateSpeedTimeout(*testing.T) { + app := eudore.NewApp() + app.SetHandler(http.TimeoutHandler(app, time.Second/4, "")) + + // 测试数据限速16B + app.AddMiddleware(middleware.NewRateSpeedFunc(160, 32, app.Context)) + app.AnyFunc("/bytes", func(ctx eudore.Context) { + for i := 0; i < 10; i++ { + _, err := ctx.Write([]byte("rate speed =16B\n")) + if err != nil { + return + } + } + }) + app.AnyFunc("/string", func(ctx eudore.Context) { + for i := 0; i < 10; i++ { + _, err := ctx.WriteString("rate speed =16B\n") + if err != nil { + return + } + } + }) + app.PostFunc("/post", func(ctx eudore.Context) { + ctx.Debug(string(ctx.Body())) + }) + app.AnyFunc("/*", eudore.HandlerEmpty) + + app.NewRequest(nil, "GET", "/bytes") + app.NewRequest(nil, "GET", "/string") + app.NewRequest(nil, "POST", "/post", strings.NewReader("read body is to long, body太长,会中间件超时无法完全读取。")) + + app.CancelFunc() + app.Run() +} + +/* +goos: linux +goarch: amd64 +BenchmarkMiddlewareBlackTree-2 1000000 1212 ns/op 0 B/op 0 allocs/op +BenchmarkMiddlewareBlackArray-2 1000000 1956 ns/op 0 B/op 0 allocs/op +BenchmarkMiddlewareBlackIp2intbit-2 1000000 1654 ns/op 320 B/op 5 allocs/op +BenchmarkMiddlewareBlackNetParse-2 1000000 1989 ns/op 360 B/op 20 allocs/op +PASS +ok command-line-arguments 6.919s +*/ + +var ips []string = []string{ + "10.0.0.0/4", "127.0.0.1/8", "192.168.1.0/24", "192.168.75.0/24", "192.168.100.0/24", +} + +var requests []uint64 = []uint64{ + 725415979, 2727437335, 889276411, 4005535794, 3864288534, 3906172701, 282878927, 1284469666, 730935782, 3371086418, + 1506312450, 1351422527, 1427742110, 1787801507, 2252116061, 229145224, 2463885032, 977944943, 3785363053, 3752670878, + 1109101831, 523139815, 2692892509, 822628332, 1521829731, 1137604504, 3946127316, 3492727158, 3701842868, 1345785201, + 2479587981, 1525387624, 2335875430, 2742578379, 842531784, 4164034788, 4067025409, 3579565778, 1135250289, 2272239320, + 2221887036, 47163049, 756685807, 3064055796, 2298095091, 3099116819, 4070972416, 1014033, 3023215026, 555430525, + 3702021454, 2340802113, 2507760403, 510831888, 3073321492, 4221140315, 1198583294, 1495418697, 827583711, 813333453, + 2746343126, 3755199452, 1697814659, 365059279, 3478405321, 2147566177, 281339662, 2742376600, 2293307920, 2061663865, + 913999062, 542572186, 4225265321, 633066366, 2063795404, 522841846, 195572401, 124532676, 2456662794, 3902204181, + 2491401143, 4233234751, 69766498, 388520887, 1017105985, 62871287, 3328355052, 1705168586, 2260082173, 3340006743, + 2211140888, 1906467873, 1247205260, 1492905294, 1014862918, 2587182986, 1040587870, 3570772999, 3084952258, 2425691705, +} + +var requeststrs []string = []string{ + "43.60.248.43", "162.145.100.23", "53.1.71.251", "238.191.160.50", "230.84.93.22", "232.211.119.29", "16.220.99.207", "76.143.115.162", "43.145.49.230", "200.238.178.82", + "89.200.129.2", "80.141.18.63", "85.25.157.158", "106.143.175.163", "134.60.144.93", "13.168.122.136", "146.219.230.232", "58.74.65.111", "225.160.14.109", "223.173.54.158", + "66.27.141.7", "31.46.122.231", "160.130.71.93", "49.8.79.236", "90.181.71.99", "67.206.119.152", "235.53.31.212", "208.46.201.118", "220.165.163.180", "80.55.13.113", + "147.203.130.141", "90.235.145.104", "139.58.161.102", "163.120.108.203", "50.56.3.200", "248.50.32.228", "242.105.226.1", "213.91.214.210", "67.170.139.113", "135.111.158.216", + "132.111.78.60", "2.207.166.169", "45.26.27.239", "182.161.199.244", "136.250.37.243", "184.184.197.19", "242.166.28.0", "0.15.121.17", "180.50.153.178", "33.27.50.125", + "220.168.93.78", "139.133.206.65", "149.121.99.19", "30.114.173.16", "183.47.42.20", "251.153.125.91", "71.112.237.254", "89.34.71.73", "49.83.236.223", "48.122.123.205", + "163.177.222.214", "223.211.203.220", "101.50.152.131", "21.194.92.207", "207.84.64.201", "128.1.66.97", "16.196.231.14", "163.117.88.152", "136.177.26.16", "122.226.126.121", + "54.122.132.214", "32.86.254.154", "251.216.110.169", "37.187.211.126", "123.3.4.204", "31.41.238.246", "11.168.50.177", "7.108.55.196", "146.109.179.10", "232.150.233.21", + "148.127.195.183", "252.82.9.63", "4.40.141.98", "23.40.91.183", "60.159.206.65", "3.191.86.247", "198.98.170.236", "101.162.206.202", "134.182.29.253", "199.20.117.87", + "131.203.85.24", "113.162.100.33", "74.86.215.140", "88.251.237.78", "60.125.148.70", "154.53.71.138", "62.6.28.94", "212.213.172.7", "183.224.162.194", "144.149.30.57", +} + +/* +func TestMiddlewareBlackResult(t *testing.T) { + tree := new(middleware.BlackNode) + array := new(BlackNodeArray) + for _, ip := range ips { + tree.Insert(ip) + array.Insert(ip) + } + for _, ip := range requests { + if tree.Look(ip) != array.Look(ip) { + t.Logf("tree: %t array: %t result not equal %d %s", tree.Look(ip), array.Look(ip), ip, int2ip(ip)) + } + } +} + +func BenchmarkMiddlewareBlackTree(b *testing.B) { + node := new(middleware.BlackNode) + for _, ip := range ips { + node.Insert(ip) + } + b.ReportAllocs() + for i := 0; i < b.N; i++ { + for _, ip := range requests { + node.Look(ip) + } + } +} +*/ + +func BenchmarkMiddlewareBlackArray(b *testing.B) { + node := new(BlackNodeArray) + b.ReportAllocs() + for _, ip := range ips { + node.Insert(ip) + } + for i := 0; i < b.N; i++ { + for _, ip := range requests { + node.Look(ip) + } + } +} + +func TestMiddlewareBlackParseip(t *testing.T) { + for _, ip := range ips { + ip1, bit1 := ip2intbit(ip) + ip2, bit2 := ip2netintbit(ip) + if ip1 != ip2 || bit1 != bit2 { + t.Log("ip parse error", ip, ip1, ip2, bit1, bit2) + } + } +} + +func BenchmarkMiddlewareBlackIp2intbit(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + for _, ip := range ips { + ip2intbit(ip) + } + } +} + +func BenchmarkMiddlewareBlackNetParse(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + for _, ip := range ips { + ip2netintbit(ip) + } + } +} + +// BlackNodeArray 定义数组遍历实现ip解析 +type BlackNodeArray struct { + Data []uint64 + Mask []uint + Count []uint64 +} + +// Insert 方法给黑名单节点新增一个ip或ip段。 +func (node *BlackNodeArray) Insert(ip string) { + iip, bit := ip2intbit(ip) + node.Data = append(node.Data, iip>>(32-bit)) + node.Mask = append(node.Mask, 32-bit) + node.Count = append(node.Count, 0) +} + +// Look 方法匹配ip是否在黑名单节点,命中则节点计数加一。 +func (node *BlackNodeArray) Look(ip uint64) bool { + for i := range node.Data { + if node.Data[i] == (ip >> node.Mask[i]) { + node.Count[i]++ + return true + } + } + return false +} + +// BlackNodeArrayNet 定义基于net库实现ip遍历匹配,支持ipv6. +type BlackNodeArrayNet struct { + Data []net.IP + Mask []net.IPMask + Count []uint64 +} + +// Insert 方法给黑名单节点新增一个ip或ip段。 +func (node *BlackNodeArrayNet) Insert(ip string) { + _, ipnet, _ := net.ParseCIDR(ip) + node.Data = append(node.Data, ipnet.IP) + node.Mask = append(node.Mask, ipnet.Mask) + node.Count = append(node.Count, 0) +} + +// Look 方法匹配ip是否在黑名单节点,命中则节点计数加一。 +func (node *BlackNodeArrayNet) Look(ip string) bool { + netip := net.ParseIP(ip) + for i := range node.Data { + if node.Data[i].Equal(netip.Mask(node.Mask[i])) { + node.Count[i]++ + return true + } + } + return false +} + +func ip2netintbit(ip string) (uint64, uint) { + ipaddr, ipnet, _ := net.ParseCIDR(ip) + length := len(ipaddr) + bit, _ := ipnet.Mask.Size() + var sum uint64 + sum += uint64(ipaddr[length-4]) << 24 + sum += uint64(ipaddr[length-3]) << 16 + sum += uint64(ipaddr[length-2]) << 8 + sum += uint64(ipaddr[length-1]) + return sum, uint(bit) +} + +func ip2intbit(ip string) (uint64, uint) { + bit := 32 + pos := strings.Index(ip, "/") + if pos != -1 { + bit, _ = strconv.Atoi(ip[pos+1:]) + ip = ip[:pos] + } + return ip2int(ip), uint(bit) +} + +func ip2int(ip string) uint64 { + bits := strings.Split(ip, ".") + b0, _ := strconv.Atoi(bits[0]) + b1, _ := strconv.Atoi(bits[1]) + b2, _ := strconv.Atoi(bits[2]) + b3, _ := strconv.Atoi(bits[3]) + + var sum uint64 + sum += uint64(b0) << 24 + sum += uint64(b1) << 16 + sum += uint64(b2) << 8 + sum += uint64(b3) + return sum +} + +func int2ip(ip uint64) string { + var bytes [4]uint64 + bytes[0] = ip & 0xFF + bytes[1] = (ip >> 8) & 0xFF + bytes[2] = (ip >> 16) & 0xFF + bytes[3] = (ip >> 24) & 0xFF + return fmt.Sprintf("%d.%d.%d.%d", bytes[3], bytes[2], bytes[1], bytes[0]) +} + +func BenchmarkMiddlewareRewrite(b *testing.B) { + rewritedata := map[string]string{ + "/js/*": "/public/js/$0", + "/api/v1/users/*/orders/*": "/api/v3/user/$0/order/$1", + "/d/*": "/d/$0-$0", + "/api/v1/*": "/api/v3/$0", + "/api/v2/*": "/api/v3/$0", + "/help/history*": "/api/v3/history", + "/help/history": "/api/v3/history", + "/help/*": "$0", + } + + app := eudore.NewApp() + app.SetValue(eudore.ContextKeyLogger, eudore.NewLoggerInit()) + app.AddMiddleware("global", middleware.NewRewriteFunc(rewritedata)) + app.AnyFunc("/*", eudore.HandlerEmpty) + paths := []string{"/", "/js/", "/js/index.js", "/api/v1/user", "/api/v1/user/new", "/api/v1/users/v3/orders/8920", "/api/v1/users/orders", "/api/v2", "/api/v2/user", "/d/3", "/help/history", "/help/historyv2"} + w, r := httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + for _, path := range paths { + r.URL.Path = path + app.ServeHTTP(w, r) + } + } +} + +func BenchmarkMiddlewareRewriteWithZero(b *testing.B) { + app := eudore.NewApp() + app.SetValue(eudore.ContextKeyLogger, eudore.NewLoggerInit()) + app.AnyFunc("/*", eudore.HandlerEmpty) + paths := []string{"/", "/js/", "/js/index.js", "/api/v1/user", "/api/v1/user/new", "/api/v1/users/v3/orders/8920", "/api/v1/users/orders", "/api/v2", "/api/v2/user", "/d/3", "/help/history", "/help/historyv2"} + w, r := httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + for _, path := range paths { + r.URL.Path = path + app.ServeHTTP(w, r) + } + } +} + +func BenchmarkMiddlewareRewriteWithRouter(b *testing.B) { + routerdata := map[string]interface{}{ + "/js/*0": newRewriteFunc("/public/js/$0"), + "/api/v1/users/:0/orders/*1": newRewriteFunc("/api/v3/user/$0/order/$1"), + "/d/*0": newRewriteFunc("/d/$0-$0"), + "/api/v1/*0": newRewriteFunc("/api/v3/$0"), + "/api/v2/*0": newRewriteFunc("/api/v3/$0"), + "/help/history*0": newRewriteFunc("/api/v3/history"), + "/help/history": newRewriteFunc("/api/v3/history"), + "/help/*0": newRewriteFunc("$0"), + } + app := eudore.NewApp() + app.SetValue(eudore.ContextKeyLogger, eudore.NewLoggerInit()) + app.AddMiddleware("global", middleware.NewRouterFunc(routerdata)) + app.AnyFunc("/*", eudore.HandlerEmpty) + paths := []string{"/", "/js/", "/js/index.js", "/api/v1/user", "/api/v1/user/new", "/api/v1/users/v3/orders/8920", "/api/v1/users/orders", "/api/v2", "/api/v2/user", "/d/3", "/help/history", "/help/historyv2"} + w, r := httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + for _, path := range paths { + r.URL.Path = path + app.ServeHTTP(w, r) + } + } +} + +func newRewriteFunc(path string) eudore.HandlerFunc { + paths := strings.Split(path, "$") + Index := make([]string, 1, len(paths)*2-1) + Data := make([]string, 1, len(paths)*2-1) + Index[0] = "" + Data[0] = paths[0] + for _, path := range paths[1:] { + Index = append(Index, path[0:1]) + Data = append(Data, "") + if path[1:] != "" { + Index = append(Index, "") + Data = append(Data, path[1:]) + } + } + return func(ctx eudore.Context) { + buffer := bytes.NewBuffer(nil) + for i := range Index { + if Index[i] == "" { + buffer.WriteString(Data[i]) + } else { + buffer.WriteString(ctx.GetParam(Index[i])) + } + } + ctx.Request().URL.Path = buffer.String() + } +} diff --git a/_example/middleware3_test.go b/_example/middleware3_test.go new file mode 100644 index 0000000..708013e --- /dev/null +++ b/_example/middleware3_test.go @@ -0,0 +1,125 @@ +package eudore_test + +import ( + "net/http" + "net/url" + "strings" + "testing" + + "github.com/eudore/eudore" + "github.com/eudore/eudore/middleware" +) + +func TestMiddlewareNethttpBasicAuth(*testing.T) { + data := map[string]string{"user": "pw"} + + app := eudore.NewApp() + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {}) + app.SetHandler(middleware.NewNetHTTPBasicAuthFunc(mux, data)) + + app.NewRequest(nil, "GET", "/1") + app.NewRequest(nil, "GET", "/2", http.Header{"Authorization": {"Basic dXNlcjpwdw=="}}) + + app.CancelFunc() + app.Run() +} + +func TestMiddlewareNethttpBodyLimit(*testing.T) { + app := eudore.NewApp() + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {}) + app.SetHandler(middleware.NewNetHTTPBodyLimitFunc(mux, 32)) + + app.NewRequest(nil, "GET", "/1") + app.NewRequest(nil, "GET", "/2", strings.NewReader("body")) + app.NewRequest(nil, "GET", "/3", strings.NewReader("1234567890abcdefghijklmnopqrstuvwxyz")) + app.NewRequest(nil, "GET", "/4", eudore.NewClientBodyForm(url.Values{"name": {"eudore"}})) + + app.CancelFunc() + app.Run() +} + +func TestMiddlewareNethttpBlack(*testing.T) { + app := eudore.NewApp() + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {}) + app.SetHandler(middleware.NewNetHTTPBlackFunc(mux, map[string]bool{ + "127.0.0.1/8": true, + "192.168.0.0/16": true, + "10.0.0.0/8": false, + })) + + app.NewRequest(nil, "GET", "/eudore/debug/black/ui") + app.NewRequest(nil, "GET", "/eudore/debug/black/ui") + app.NewRequest(nil, "PUT", "/eudore/debug/black/black/10.127.87.0?mask=24") + app.NewRequest(nil, "PUT", "/eudore/debug/black/white/10.127.87.0?mask=24") + app.NewRequest(nil, "GET", "/eudore/debug/black/data") + app.NewRequest(nil, "DELETE", "/eudore/debug/black/black/10.127.87.0?mask=24") + app.NewRequest(nil, "DELETE", "/eudore/debug/black/white/10.127.87.0?mask=24") + app.NewRequest(nil, "DELETE", "/eudore/debug/black/black/10.127.87.0?mask=24") + app.NewRequest(nil, "DELETE", "/eudore/debug/black/white/10.127.87.0?mask=24") + + app.NewRequest(nil, "GET", "/eudore") + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXForwardedFor: {"192.168.1.4 192.168.1.1"}}) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"127.0.0.1:29398"}}) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"192.168.75.1:8298"}}) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"10.1.1.1:2334"}}) + app.NewRequest(nil, "GET", "/eudore", http.Header{eudore.HeaderXRealIP: {"172.17.1.1:2334"}}) + + app.CancelFunc() + app.Run() +} + +func TestMiddlewareNethttpRateRequest(*testing.T) { + app := eudore.NewApp() + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {}) + app.SetHandler(middleware.NewNetHTTPRateRequestFunc(mux, 1, 3, func(req *http.Request) string { + // 自定义限流key + return req.UserAgent() + })) + + app.NewRequest(nil, "GET", "/") + app.NewRequest(nil, "GET", "/") + app.NewRequest(nil, "GET", "/") + app.NewRequest(nil, "GET", "/") + app.NewRequest(nil, "GET", "/") + + app.CancelFunc() + app.Run() +} + +func TestMiddlewareNethttpRewrite(*testing.T) { + rewritedata := map[string]string{ + "/js/*": "/public/js/$0", + "/api/v1/users/*/orders/*": "/api/v3/user/$0/order/$1", + "/d/*": "/d/$0-$0", + "/api/v1/*": "/api/v3/$0", + "/api/v2/*": "/api/v3/$0", + "/help/history*": "/api/v3/history", + "/help/history": "/api/v3/history", + "/help/*": "$0", + } + + app := eudore.NewApp() + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {}) + app.SetHandler(middleware.NewNetHTTPRewriteFunc(mux, rewritedata)) + + app.NewRequest(nil, "GET", "/") + app.NewRequest(nil, "GET", "/js/") + app.NewRequest(nil, "GET", "/js/index.js") + app.NewRequest(nil, "GET", "/api/v1/user") + app.NewRequest(nil, "GET", "/api/v1/user/new") + app.NewRequest(nil, "GET", "/api/v1/users/v3/orders/8920") + app.NewRequest(nil, "GET", "/api/v1/users/orders") + app.NewRequest(nil, "GET", "/api/v2") + app.NewRequest(nil, "GET", "/api/v2/user") + app.NewRequest(nil, "GET", "/d/3") + app.NewRequest(nil, "GET", "/help/history") + app.NewRequest(nil, "GET", "/help/historyv2") + + app.CancelFunc() + app.Run() +} diff --git a/_example/middleware_test.go b/_example/middleware_test.go index a7c7e66..e4538cc 100644 --- a/_example/middleware_test.go +++ b/_example/middleware_test.go @@ -2,17 +2,15 @@ package eudore_test import ( "bufio" - "bytes" + "compress/gzip" "context" "encoding/json" "fmt" "io" "net" "net/http" - "net/http/httptest" - "strconv" + "net/url" "strings" - "sync" "testing" "time" @@ -37,8 +35,8 @@ func TestMiddlewareBasicAuth(*testing.T) { app := eudore.NewApp() app.AddMiddleware("global", middleware.NewBasicAuthFunc(map[string]string{"eudore": "hello"})) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderAuthorization, "Basic ZXVkb3JlOmhlbGxv")) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderAuthorization, "eudore")) + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderAuthorization: {"Basic ZXVkb3JlOmhlbGxv"}}) + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderAuthorization: {"eudore"}}) app.CancelFunc() app.Run() @@ -46,17 +44,27 @@ func TestMiddlewareBasicAuth(*testing.T) { func TestMiddlewareBodyLimit(*testing.T) { app := eudore.NewApp() - app.AddMiddleware("global", middleware.NewBodyLimitFunc(32)) + app.AddMiddleware("global", + middleware.NewCompressGzipFunc(), + middleware.NewBodyLimitFunc(32), + ) app.AnyFunc("/", func(ctx eudore.Context) { ctx.Body() }) + app.AnyFunc("/form", func(ctx eudore.Context) { + ctx.FormValues() + }) - app.NewRequest(nil, "GET", "/", eudore.NewClientBodyString("123456")) - app.NewRequest(nil, "GET", "/", eudore.NewClientBodyString("1234567890abcdefghijklmnopqrstuvwxyz")) + app.NewRequest(nil, "GET", "/") + app.NewRequest(nil, "GET", "/", strings.NewReader("123456")) + app.NewRequest(nil, "GET", "/", strings.NewReader("1234567890abcdefghijklmnopqrstuvwxyz")) // limit chunck - app.NewRequest(nil, "GET", "/", eudore.NewClientBodyFormValues(map[string]string{ - "name": "eudore", "value": "1234567890abcdefghijklmnopqrstuvwxyz", - })) + data := url.Values{ + "name": {"eudore"}, + "value": {"1234567890abcdefghijklmnopqrstuvwxyz"}, + } + app.NewRequest(nil, "GET", "/", eudore.NewClientBodyForm(data)) + app.NewRequest(nil, "GET", "/form", eudore.NewClientBodyForm(data)) app.CancelFunc() app.Run() @@ -100,7 +108,7 @@ func (ctx contextParams) GetParam(key string) string { func TestMiddlewareHeader(*testing.T) { app := eudore.NewApp() - app.AddMiddleware("global", middleware.NewHeaderWithSecureFunc(http.Header{"Server": []string{"eudore"}})) + app.AddMiddleware("global", middleware.NewHeaderWithSecureFunc(http.Header{"Server": {"eudore"}})) app.AddMiddleware("global", middleware.NewHeaderFunc(nil)) app.NewRequest(nil, "GET", "/") @@ -133,7 +141,7 @@ func TestMiddlewareLogger(*testing.T) { ctx.Fatal("test error") }) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderXForwardedFor, "172.17.0.1")) + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderXForwardedFor: {"172.17.0.1"}}) app.NewRequest(nil, "POST", "/500") app.CancelFunc() @@ -191,217 +199,6 @@ func TestMiddlewareRequestID(*testing.T) { app.Run() } -func TestMiddlewareBlack(*testing.T) { - middleware.NewBlackFunc(map[string]bool{ - "192.168.0.0/16": true, - "0.0.0.0/0": false, - }, nil) - - app := eudore.NewApp() - app.AddMiddleware(middleware.NewBlackFunc(map[string]bool{ - "192.168.100.0/24": true, - "192.168.75.0/30": true, - "192.168.1.100/30": true, - "127.0.0.1/32": true, - "10.168.0.0/16": true, - "0.0.0.0/0": false, - }, app.Group("/eudore/debug"))) - app.AnyFunc("/*", eudore.HandlerEmpty) - - app.NewRequest(nil, "GET", "/eudore/debug/black/ui") - app.NewRequest(nil, "GET", "/eudore/debug/black/ui") - app.NewRequest(nil, "PUT", "/eudore/debug/black/black/10.127.87.0?mask=24") - app.NewRequest(nil, "PUT", "/eudore/debug/black/white/10.127.87.0?mask=24") - app.NewRequest(nil, "GET", "/eudore/debug/black/data") - app.NewRequest(nil, "DELETE", "/eudore/debug/black/black/10.127.87.0?mask=24") - app.NewRequest(nil, "DELETE", "/eudore/debug/black/white/10.127.87.0?mask=24") - app.NewRequest(nil, "DELETE", "/eudore/debug/black/black/10.127.87.0?mask=24") - app.NewRequest(nil, "DELETE", "/eudore/debug/black/white/10.127.87.0?mask=24") - - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "127.0.0.1:29398"), eudore.NewClientCheckStatus(403)) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "127.0.0.1:29398"), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "192.168.75.1:8298"), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "192.168.100.3/28"), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "192.168.100.0"), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "192.168.100.1"), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "192.168.100.77"), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "192.168.100.148"), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "192.168.100.222"), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "192.168.75.4"), eudore.NewClientCheckStatus(403)) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "192.168.75.5"), eudore.NewClientCheckStatus(403)) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "192.168.75.6"), eudore.NewClientCheckStatus(403)) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "192.168.1.99"), eudore.NewClientCheckStatus(403)) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "192.168.1.100"), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "192.168.1.101"), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "192.168.1.102"), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "192.168.1.103"), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "192.168.1.104"), eudore.NewClientCheckStatus(403)) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "192.168.1.105"), eudore.NewClientCheckStatus(403)) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "127.0.0.1")) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientCheckStatus(403)) - - app.NewRequest(nil, "DELETE", "/eudore/debug/black/white/0.0.0.0?mask=0") - app.NewRequest(nil, "PUT", "/eudore/debug/black/white/192.168.75.4?mask=30") - app.NewRequest(nil, "DELETE", "/eudore/debug/black/white/192.168.75.1") - app.NewRequest(nil, "DELETE", "/eudore/debug/black/white/192.168.75.5") - app.NewRequest(nil, "DELETE", "/eudore/debug/black/white/192.168.75.7") - app.NewRequest(nil, "PUT", "/eudore/debug/black/white/10.16.0.0?mask=16") - app.NewRequest(nil, "DELETE", "/eudore/debug/black/white/192.168.75.4?mask=30") - - app.CancelFunc() - app.Run() -} - -func TestMiddlewareBreaker(*testing.T) { - middleware.NewBreakerFunc(nil) - - app := eudore.NewApp() - // 创建熔断器并注入管理路由 - breaker := middleware.NewBreaker() - breaker.MaxConsecutiveSuccesses = 3 - breaker.MaxConsecutiveFailures = 3 - breaker.OpenWait = 0 - app.AddMiddleware(middleware.NewLoggerFunc(app, "route")) - app.AddMiddleware(breaker.NewBreakerFunc(app.Group("/eudore/debug"))) - app.AnyFunc("/*", func(ctx eudore.Context) { - if len(ctx.Querys()) > 0 { - ctx.Fatal("test err") - return - } - ctx.WriteString("route: " + ctx.GetParam("route")) - }) - - // 错误请求 - for i := 0; i < 10; i++ { - app.NewRequest(nil, "GET", "/1?a=1") - } - for i := 0; i < 5; i++ { - time.Sleep(time.Millisecond * 500) - app.NewRequest(nil, "GET", "/1?a=1") - } - // 除非熔断后访问 - for i := 0; i < 5; i++ { - time.Sleep(time.Millisecond * 500) - app.NewRequest(nil, "GET", "/1") - } - - app.NewRequest(nil, "GET", "/eudore/debug/breaker/ui") - app.NewRequest(nil, "GET", "/eudore/debug/breaker/ui") - app.NewRequest(nil, "GET", "/eudore/debug/breaker/data", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeApplicationJSON)) - app.NewRequest(nil, "GET", "/eudore/debug/breaker/1") - app.NewRequest(nil, "GET", "/eudore/debug/breaker/100") - app.NewRequest(nil, "PUT", "/eudore/debug/breaker/1/state/0") - app.NewRequest(nil, "PUT", "/eudore/debug/breaker/1/state/3") - app.NewRequest(nil, "PUT", "/eudore/debug/breaker/3/state/3") - - time.Sleep(time.Microsecond * 100) - app.CancelFunc() - app.Run() -} - -func TestMiddlewareCache(*testing.T) { - app := eudore.NewApp() - app.AddMiddleware("global", middleware.NewLoggerFunc(app, "route")) - app.AddMiddleware(middleware.NewCacheFunc(time.Second/10, app.Context, func(ctx eudore.Context) string { - // 自定义缓存key函数,默认实现方法 - if ctx.Method() != eudore.MethodGet || ctx.GetHeader(eudore.HeaderUpgrade) != "" { - return "" - } - return ctx.Request().URL.RequestURI() - })) - app.AnyFunc("/sf", func(ctx eudore.Context) { - ctx.Redirect(301, "/") - ctx.Debug(ctx.Response().Status(), ctx.Response().Size()) - }) - app.AnyFunc("/*", func(ctx eudore.Context) { - time.Sleep(time.Second / 3) - ctx.WriteString("hello eudore") - }) - - app.NewRequest(nil, "GET", "/sf") - wg := sync.WaitGroup{} - wg.Add(5) - for n := 0; n < 5; n++ { - go func() { - for i := 0; i < 3; i++ { - app.NewRequest(nil, "GET", "/?c="+fmt.Sprint(i)) - app.NewRequest(nil, "GET", "/?c="+fmt.Sprint(i)) - time.Sleep(time.Millisecond * 200) - app.NewRequest(nil, "GET", "/?c="+fmt.Sprint(i)) - } - wg.Done() - }() - } - wg.Wait() - - app.NewRequest(nil, "GET", "/sf") - app.NewRequest(nil, "POST", "/sf") - app.NewRequest(nil, "GET", "/s") - - app.CancelFunc() - app.Run() -} - -func TestMiddlewareCacheStore(*testing.T) { - app := eudore.NewApp() - app.AddMiddleware("global", middleware.NewLoggerFunc(app, "route")) - app.AddMiddleware(middleware.NewCacheFunc(time.Second/100, app.Context, new(cacheMap))) - app.AnyFunc("/sf", func(ctx eudore.Context) { - ctx.Redirect(301, "/") - ctx.Debug(ctx.Response().Status(), ctx.Response().Size()) - }) - app.AnyFunc("/*", func(ctx eudore.Context) { - time.Sleep(time.Second / 3) - ctx.WriteString("hello eudore") - }) - - app.NewRequest(nil, "GET", "/sf") - wg := sync.WaitGroup{} - wg.Add(5) - for n := 0; n < 5; n++ { - go func() { - for i := 0; i < 3; i++ { - app.NewRequest(nil, "GET", "/?c="+fmt.Sprint(i)) - app.NewRequest(nil, "GET", "/?c="+fmt.Sprint(i)) - time.Sleep(time.Millisecond * 20) - app.NewRequest(nil, "GET", "/?c="+fmt.Sprint(i)) - } - wg.Done() - }() - } - wg.Wait() - - app.NewRequest(nil, "GET", "/sf") - app.NewRequest(nil, "POST", "/sf") - app.NewRequest(nil, "GET", "/s") - - app.CancelFunc() - app.Run() -} - -type cacheMap struct { - sync.Map -} - -func (m *cacheMap) Load(key string) *middleware.CacheData { - data, ok := m.Map.Load(key) - if !ok { - return nil - } - item := data.(*middleware.CacheData) - if time.Now().After(item.Expired) { - m.Map.Delete(key) - return nil - } - fmt.Println("cache", key) - return item -} - -func (m *cacheMap) Store(key string, val *middleware.CacheData) { - fmt.Println("new", key) - m.Map.Store(key, val) -} - func TestMiddlewareCors(*testing.T) { middleware.NewCorsFunc(nil, map[string]string{ "Access-Control-Allow-Credentials": "true", @@ -420,19 +217,19 @@ func TestMiddlewareCors(*testing.T) { })) app.NewRequest(nil, "OPTIONS", "/1") - app.NewRequest(nil, "OPTIONS", "/2", eudore.NewClientHeader("Origin", eudore.DefaultClientInternalHost)) - app.NewRequest(nil, "OPTIONS", "/3", eudore.NewClientHeader("Origin", "http://localhost")) - app.NewRequest(nil, "OPTIONS", "/4", eudore.NewClientHeader("Origin", "http://127.0.0.1:8088")) - app.NewRequest(nil, "OPTIONS", "/5", eudore.NewClientHeader("Origin", "http://127.0.0.1:8089")) - app.NewRequest(nil, "OPTIONS", "/6", eudore.NewClientHeader("Origin", "http://example.com")) - app.NewRequest(nil, "OPTIONS", "/6", eudore.NewClientHeader("Origin", "http://www.eudore.cn")) + app.NewRequest(nil, "OPTIONS", "/2", http.Header{eudore.HeaderOrigin: {eudore.DefaultClientInternalHost}}) + app.NewRequest(nil, "OPTIONS", "/3", http.Header{eudore.HeaderOrigin: {"http://localhost"}}) + app.NewRequest(nil, "OPTIONS", "/4", http.Header{eudore.HeaderOrigin: {"http://127.0.0.1:8088"}}) + app.NewRequest(nil, "OPTIONS", "/5", http.Header{eudore.HeaderOrigin: {"http://127.0.0.1:8089"}}) + app.NewRequest(nil, "OPTIONS", "/6", http.Header{eudore.HeaderOrigin: {"http://example.com"}}) + app.NewRequest(nil, "OPTIONS", "/6", http.Header{eudore.HeaderOrigin: {"http://www.eudore.cn"}}) app.NewRequest(nil, "GET", "/1") - app.NewRequest(nil, "GET", "/2", eudore.NewClientHeader("Origin", eudore.DefaultClientHost)) - app.NewRequest(nil, "GET", "/3", eudore.NewClientHeader("Origin", "http://localhost")) - app.NewRequest(nil, "GET", "/4", eudore.NewClientHeader("Origin", "http://127.0.0.1:8088")) - app.NewRequest(nil, "GET", "/5", eudore.NewClientHeader("Origin", "http://127.0.0.1:8089")) - app.NewRequest(nil, "GET", "/6", eudore.NewClientHeader("Origin", "http://example.com")) - app.NewRequest(nil, "GET", "/6", eudore.NewClientHeader("Origin", "http://www.eudore.cn")) + app.NewRequest(nil, "GET", "/2", http.Header{eudore.HeaderOrigin: {eudore.DefaultClientHost}}) + app.NewRequest(nil, "GET", "/3", http.Header{eudore.HeaderOrigin: {"http://localhost"}}) + app.NewRequest(nil, "GET", "/4", http.Header{eudore.HeaderOrigin: {"http://127.0.0.1:8088"}}) + app.NewRequest(nil, "GET", "/5", http.Header{eudore.HeaderOrigin: {"http://127.0.0.1:8089"}}) + app.NewRequest(nil, "GET", "/6", http.Header{eudore.HeaderOrigin: {"http://example.com"}}) + app.NewRequest(nil, "GET", "/6", http.Header{eudore.HeaderOrigin: {"http://www.eudore.cn"}}) app.CancelFunc() app.Run() @@ -441,8 +238,8 @@ func TestMiddlewareCors(*testing.T) { func TestMiddlewareCsrf(*testing.T) { app := eudore.NewApp() app.AnyFunc("/query", middleware.NewCsrfFunc("query: csrf", "_csrf"), eudore.HandlerEmpty) - app.AnyFunc("/header", middleware.NewCsrfFunc("header: "+eudore.HeaderXCSRFToken, eudore.SetCookie{Name: "_csrf", MaxAge: 86400}), eudore.HandlerEmpty) - app.AnyFunc("/form", middleware.NewCsrfFunc("form: csrf", &eudore.SetCookie{Name: "_csrf", MaxAge: 86400}), eudore.HandlerEmpty) + app.AnyFunc("/header", middleware.NewCsrfFunc("header: "+eudore.HeaderXCSRFToken, eudore.CookieSet{Name: "_csrf", MaxAge: 86400}), eudore.HandlerEmpty) + app.AnyFunc("/form", middleware.NewCsrfFunc("form: csrf", &eudore.CookieSet{Name: "_csrf", MaxAge: 86400}), eudore.HandlerEmpty) app.AnyFunc("/fn", middleware.NewCsrfFunc(func(ctx eudore.Context) string { return ctx.GetQuery("csrf") }, "_csrf"), eudore.HandlerEmpty) app.AnyFunc("/*", middleware.NewCsrfFunc(nil, nil), eudore.HandlerEmpty) @@ -456,13 +253,13 @@ func TestMiddlewareCsrf(*testing.T) { }, ) app.NewRequest(nil, "POST", "/2", eudore.NewClientCheckStatus(400)) - app.NewRequest(nil, "POST", "/1", eudore.NewClientQuery("csrf", csrfval), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "POST", "/query", eudore.NewClientQuery("csrf", csrfval), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "POST", "/header", eudore.NewClientHeader(eudore.HeaderXCSRFToken, csrfval), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "POST", "/form", eudore.NewClientBodyFormValue("csrf", csrfval), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "POST", "/form", eudore.NewClientBodyJSONValue("csrf", csrfval), eudore.NewClientCheckStatus(400)) - app.NewRequest(nil, "POST", "/fn", eudore.NewClientQuery("csrf", csrfval), eudore.NewClientCheckStatus(200)) - app.NewRequest(nil, "POST", "/nil", eudore.NewClientQuery("csrf", csrfval), eudore.NewClientCheckStatus(200)) + app.NewRequest(nil, "POST", "/1", url.Values{"csrf": {csrfval}}, eudore.NewClientCheckStatus(200)) + app.NewRequest(nil, "POST", "/query", url.Values{"csrf": {csrfval}}, eudore.NewClientCheckStatus(200)) + app.NewRequest(nil, "POST", "/header", http.Header{eudore.HeaderXCSRFToken: {csrfval}}, eudore.NewClientCheckStatus(200)) + app.NewRequest(nil, "POST", "/form", eudore.NewClientBodyForm(url.Values{"csrf": {csrfval}}), eudore.NewClientCheckStatus(200)) + app.NewRequest(nil, "POST", "/form", eudore.NewClientBodyJSON(map[string]any{"csrf": csrfval}), eudore.NewClientCheckStatus(400)) + app.NewRequest(nil, "POST", "/fn", url.Values{"csrf": {csrfval}}, eudore.NewClientCheckStatus(200)) + app.NewRequest(nil, "POST", "/nil", url.Values{"csrf": {csrfval}}, eudore.NewClientCheckStatus(200)) app.CancelFunc() app.Run() @@ -514,10 +311,15 @@ func TestMiddlewareDump(*testing.T) { } }) app.AddMiddleware(middleware.NewDumpFunc(app.Group("/eudore/debug"))) - app.AnyFunc("/gzip", middleware.NewCompressGzipFunc(5), func(ctx eudore.Context) { + app.AnyFunc("/gzip", middleware.NewCompressGzipFunc(), func(ctx eudore.Context) { + ctx.WriteString("gzip body") + }) + app.AnyFunc("/gziperr1", func(ctx eudore.Context) { + ctx.SetHeader(eudore.HeaderContentEncoding, "gzip") + ctx.Write([]byte("gzip body")) ctx.WriteString("gzip body") }) - app.AnyFunc("/gziperr", func(ctx eudore.Context) { + app.AnyFunc("/gziperr2", func(ctx eudore.Context) { ctx.SetHeader(eudore.HeaderContentEncoding, "gzip") ctx.WriteString("gzip body") }) @@ -537,8 +339,9 @@ func TestMiddlewareDump(*testing.T) { time.Sleep(200 * time.Millisecond) app.NewRequest(nil, "GET", "http://localhost:8088/eudore/debug/dump/connect") - app.NewRequest(nil, "GET", "/gzip", eudore.NewClientHeader(eudore.HeaderAcceptEncoding, "gzip")) - app.NewRequest(nil, "GET", "/gziperr") + app.NewRequest(nil, "GET", "/gzip", http.Header{eudore.HeaderAcceptEncoding: {"gzip"}}) + app.NewRequest(nil, "GET", "/gziperr1") + app.NewRequest(nil, "GET", "/gziperr2") app.NewRequest(nil, "GET", "/echo") app.NewRequest(nil, "GET", "/bigbody", func(resp *http.Response) { io.Copy(io.Discard, resp.Body) @@ -557,23 +360,72 @@ func (nodumpResponse022) Hijack() (net.Conn, *bufio.ReadWriter, error) { return nil, nil, fmt.Errorf("nodump") } -func TestMiddlewareGzip(*testing.T) { +func TestMiddlewareCompressGzip(*testing.T) { app := eudore.NewApp() - app.AddMiddleware(middleware.NewCompressDeflateFunc(100)) - app.AddMiddleware(middleware.NewCompressGzipFunc(10)) + app.AddMiddleware(middleware.NewCompressDeflateFunc()) + app.AddMiddleware(middleware.NewCompressGzipFunc()) app.AnyFunc("/*", func(ctx eudore.Context) { ctx.Debugf("%#v", ctx.Request().Header) - ctx.Push("/stat", nil) + ctx.WriteHeader(eudore.StatusOK) ctx.Response().Push("/stat", nil) ctx.Response().Push("/stat", &http.PushOptions{}) ctx.Response().Push("/stat", &http.PushOptions{Header: make(http.Header)}) - ctx.WriteString("gzip") + ctx.WriteString("compress") + }) + app.AnyFunc("/gzip", func(ctx eudore.Context) { + ctx.SetHeader(eudore.HeaderContentType, "application/gzip;encoding=gzip") + for i := 0; i < 20; i++ { + ctx.WriteString("1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXY") + } + ctx.Response().Flush() + }) + app.AnyFunc("/long", func(ctx eudore.Context) { + data := []byte("1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXY") + for i := 0; i < 20; i++ { + ctx.Write(data) + } + ctx.Response().Flush() + }) + app.AnyFunc("/longs", func(ctx eudore.Context) { + for i := 0; i < 20; i++ { + ctx.WriteString("1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXY") + } ctx.Response().Flush() }) - app.NewRequest(nil, "GET", "/1") - app.NewRequest(nil, "GET", "/1", eudore.NewClientHeader(eudore.HeaderAcceptEncoding, "deflate")) - app.NewRequest(nil, "GET", "/1", eudore.NewClientHeader(eudore.HeaderAcceptEncoding, "none")) + app.NewRequest(nil, "GET", "/") + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderAcceptEncoding: {middleware.CompressNameDeflate}}) + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderAcceptEncoding: {middleware.CompressNameGzip}}) + app.NewRequest(nil, "GET", "/gzip", http.Header{eudore.HeaderAcceptEncoding: {middleware.CompressNameGzip}}) + app.NewRequest(nil, "GET", "/long", http.Header{eudore.HeaderAcceptEncoding: {middleware.CompressNameGzip}}) + app.NewRequest(nil, "GET", "/longs", http.Header{eudore.HeaderAcceptEncoding: {middleware.CompressNameGzip}}) + + app.CancelFunc() + app.Run() +} + +func TestMiddlewareCompressMixins(*testing.T) { + middleware.DefaultComoressBrotliFunc = func() interface{} { + return gzip.NewWriter(io.Discard) + } + defer func() { + middleware.DefaultComoressBrotliFunc = nil + }() + + app := eudore.NewApp() + app.AddMiddleware(middleware.NewCompressMixinsFunc(nil)) + app.AnyFunc("/*", func(ctx eudore.Context) { + ctx.Debugf("%#v", ctx.Request().Header) + ctx.WriteString("mixins") + ctx.Response().Flush() + }) + + app.NewRequest(nil, "GET", "/") + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderAcceptEncoding: {middleware.CompressNameGzip}}) + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderAcceptEncoding: {middleware.CompressNameDeflate}}) + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderAcceptEncoding: {middleware.CompressNameIdentity}}) + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderAcceptEncoding: {"gzip;q=0"}}) + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderAcceptEncoding: {"none"}}) app.CancelFunc() app.Run() @@ -610,9 +462,9 @@ func TestMiddlewareLook(*testing.T) { app.NewRequest(nil, "GET", "/eudore/debug/look/?format=json") app.NewRequest(nil, "GET", "/eudore/debug/look/?format=t2") app.NewRequest(nil, "GET", "/eudore/debug/look/Config/Keys/2") - app.NewRequest(nil, "GET", "/eudore/debug/look/?d=3", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeApplicationJSON)) - app.NewRequest(nil, "GET", "/eudore/debug/look/?d=3", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeTextHTML)) - app.NewRequest(nil, "GET", "/eudore/debug/look/?d=3", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeText)) + app.NewRequest(nil, "GET", "/eudore/debug/look/?d=3", http.Header{eudore.HeaderAccept: {eudore.MimeApplicationJSON}}) + app.NewRequest(nil, "GET", "/eudore/debug/look/?d=3", http.Header{eudore.HeaderAccept: {eudore.MimeTextHTML}}) + app.NewRequest(nil, "GET", "/eudore/debug/look/?d=3", http.Header{eudore.HeaderAccept: {eudore.MimeText}}) app.CancelFunc() app.Run() @@ -631,10 +483,10 @@ func TestMiddlewareLookRender(*testing.T) { "date": time.Now(), } }) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderAccept, middleware.MimeValueJSON)) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderAccept, middleware.MimeValueJSON+","+eudore.MimeApplicationJSON)) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderAccept, middleware.MimeValueHTML)) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderAccept, middleware.MimeValueText)) + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderAccept: {middleware.MimeValueJSON}}) + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderAccept: {middleware.MimeValueJSON + "," + eudore.MimeApplicationJSON}}) + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderAccept: {middleware.MimeValueHTML}}) + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderAccept: {middleware.MimeValueText}}) // time.Sleep(100* time.Microsecond) app.CancelFunc() @@ -643,12 +495,16 @@ func TestMiddlewareLookRender(*testing.T) { func TestMiddlewarePprof(*testing.T) { app := eudore.NewApp() - app.Group("/eudore/debug").AddController(middleware.NewPprofController()) + app.AnyFunc("/eudore/debug/pprof/*", middleware.HandlerPprof) - app.NewRequest(nil, "GET", "/eudore/debug/pprof/expvar", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeApplicationJSON)) + app.NewRequest(nil, "GET", "/eudore/debug/pprof/expvar", http.Header{eudore.HeaderAccept: {eudore.MimeApplicationJSON}}) app.NewRequest(nil, "GET", "/eudore/debug/pprof/?format=json") app.NewRequest(nil, "GET", "/eudore/debug/pprof/?format=text") app.NewRequest(nil, "GET", "/eudore/debug/pprof/?format=html") + app.NewRequest(nil, "GET", "/eudore/debug/pprof/allocs") + app.NewRequest(nil, "GET", "/eudore/debug/pprof/block") + app.NewRequest(nil, "GET", "/eudore/debug/pprof/heap") + app.NewRequest(nil, "GET", "/eudore/debug/pprof/mutex") app.NewRequest(nil, "GET", "/eudore/debug/pprof/goroutine?debug=0") app.NewRequest(nil, "GET", "/eudore/debug/pprof/goroutine?debug=1") app.NewRequest(nil, "GET", "/eudore/debug/pprof/goroutine?debug=1&format=json") @@ -661,168 +517,11 @@ func TestMiddlewarePprof(*testing.T) { app.CancelFunc() app.Run() } - -func TestMiddlewareRateRequest(*testing.T) { - app := eudore.NewApp() - app.AnyFunc("/*", middleware.NewRateRequestFunc(1, 3, app.Context), eudore.HandlerEmpty) - - for i := 0; i < 8; i++ { - app.NewRequest(nil, "GET", "/") - } - - app.CancelFunc() - app.Run() -} - -func TestMiddlewareRateSpeed1(*testing.T) { - app := eudore.NewApp() - app.AddMiddleware(middleware.NewRateSpeedFunc(16*1024, 64*1024, app.Context)) - app.PostFunc("/post", func(ctx eudore.Context) { - ctx.Debug(string(ctx.Body())) - }) - app.AnyFunc("/srv", func(ctx eudore.Context) { - ctx.WriteString("rate speed 16kB") - }) - app.AnyFunc("/*", eudore.HandlerEmpty) - - app.NewRequest(nil, "POST", "/post", eudore.NewClientBodyString("return body")) - app.NewRequest(nil, "PUT", "/srv") - - app.CancelFunc() - app.Run() -} - -func TestMiddlewareRateSpeed2(*testing.T) { - app := eudore.NewApp() - app.AnyFunc("/*", middleware.NewRateRequestFunc(1, 3, app.Context, time.Millisecond*100), eudore.HandlerEmpty) - - app.NewRequest(nil, "PUT", "/") - app.NewRequest(nil, "PUT", "/") - app.NewRequest(nil, "PUT", "/") - time.Sleep(time.Second / 2) - app.NewRequest(nil, "PUT", "/") - app.NewRequest(nil, "PUT", "/") - app.NewRequest(nil, "PUT", "/") - - app.CancelFunc() - app.Run() -} - -func TestMiddlewareRateSpeed3(*testing.T) { - app := eudore.NewApp() - app.AnyFunc("/*", middleware.NewRateRequestFunc(1, 2, app.Context, time.Microsecond*49), eudore.HandlerEmpty) - - app.NewRequest(nil, "PUT", "/") - app.NewRequest(nil, "PUT", "/") - app.NewRequest(nil, "PUT", "/") - app.NewRequest(nil, "PUT", "/") - app.NewRequest(nil, "PUT", "/") - time.Sleep(time.Second) - - app.CancelFunc() - app.Run() -} - -func TestMiddlewareRateSpeedCannel1(*testing.T) { - app := eudore.NewApp() - app.AddMiddleware("/out", func(ctx eudore.Context) { - c1 := ctx.GetContext() - c2, cannel := context.WithTimeout(context.Background(), time.Millisecond*20) - go func() { - cannel() - }() - ctx.SetContext(c2) - ctx.Next() - ctx.SetContext(c1) - }) - app.AddMiddleware(middleware.NewRateRequestFunc(1, 3, app.Context, time.Millisecond*10, func(ctx eudore.Context) string { - return ctx.RealIP() - })) - app.AnyFunc("/out", eudore.HandlerEmpty) - app.AnyFunc("/*", eudore.HandlerEmpty) - - app.NewRequest(nil, "PUT", "/") - app.NewRequest(nil, "PUT", "/") - time.Sleep(50 * time.Millisecond) - app.NewRequest(nil, "PUT", "/out") - app.NewRequest(nil, "PUT", "/") - app.NewRequest(nil, "PUT", "/") - app.NewRequest(nil, "PUT", "/") - app.NewRequest(nil, "PUT", "/") - app.NewRequest(nil, "PUT", "/") - app.NewRequest(nil, "PUT", "/out") - app.NewRequest(nil, "PUT", "/") - - app.CancelFunc() - app.Run() -} - -func TestMiddlewareRateSpeedCannel2(*testing.T) { - app := eudore.NewApp() - app.AddMiddleware("/out", func(ctx eudore.Context) { - c, cannel := context.WithTimeout(ctx.GetContext(), time.Millisecond*2) - cannel() - ctx.SetContext(c) - }) - app.AddMiddleware(middleware.NewRateRequestFunc(1, 3, app.Context, time.Millisecond*10, func(ctx eudore.Context) string { - return ctx.RealIP() - })) - app.AnyFunc("/out", func(ctx eudore.Context) { - time.Sleep(time.Millisecond * 5) - }) - app.AnyFunc("/*", eudore.HandlerEmpty) - - app.NewRequest(nil, "PUT", "/") - app.NewRequest(nil, "PUT", "/") - app.NewRequest(nil, "PUT", "/") - app.NewRequest(nil, "PUT", "/") - app.NewRequest(nil, "PUT", "/out") - app.NewRequest(nil, "PUT", "/out") - app.NewRequest(nil, "PUT", "/out") - app.NewRequest(nil, "PUT", "/out") - - app.CancelFunc() - app.Run() -} - -func TestMiddlewareRateSpeedTimeout(*testing.T) { - app := eudore.NewApp() - app.SetHandler(http.TimeoutHandler(app, 2*time.Second, "")) - - // /done限速512B - app.PostFunc("/done", func(ctx eudore.Context) { - c, cannel := context.WithCancel(ctx.GetContext()) - ctx.SetContext(c) - cannel() - }, middleware.NewRateSpeedFunc(512, 1024, app.Context), func(ctx eudore.Context) { - ctx.Debug(string(ctx.Body())) - }) - - // 测试数据限速16B - app.AddMiddleware(middleware.NewRateSpeedFunc(16, 128, app.Context)) - app.AnyFunc("/get", func(ctx eudore.Context) { - for i := 0; i < 10; i++ { - ctx.WriteString("rate speed =16B\n") - } - }) - app.PostFunc("/post", func(ctx eudore.Context) { - ctx.Debug(string(ctx.Body())) - }) - app.AnyFunc("/*", eudore.HandlerEmpty) - - app.NewRequest(nil, "GET", "/get") - app.NewRequest(nil, "POST", "/post", eudore.NewClientBodyString("read body is to long,body太大,会中间件超时无法完全读取。")) - app.NewRequest(nil, "POST", "/done", eudore.NewClientBodyString("hello")) - - app.CancelFunc() - app.Run() -} - func TestMiddlewareReferer(*testing.T) { app := eudore.NewApp() app.AddMiddleware(middleware.NewRefererFunc(map[string]bool{ "": true, - "origin": false, + "origin": true, "www.eudore.cn/*": true, "www.eudore.cn/api/*": false, "www.example.com/*": true, @@ -837,15 +536,19 @@ func TestMiddlewareReferer(*testing.T) { })) app.AnyFunc("/*", eudore.HandlerEmpty) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderReferer, "")) - app.NewRequest(nil, "GET", "/", eudore.NewClientHost("www.eudore.cn"), eudore.NewClientHeader(eudore.HeaderReferer, "http://www.eudore.cn/")) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderReferer, "http://www.eudore.cn/")) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderReferer, "http://www.example.com")) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderReferer, "http://www.example.com/")) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderReferer, "http://www.example.com/1")) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderReferer, "http://www.example.com/1/1")) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderReferer, "http://www.example.com/1/2")) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderReferer, "http://127.0.0.1/1")) + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderReferer: {""}}) + app.NewRequest(nil, "GET", "/", + eudore.NewClientOptionHost("www.eudore.cn"), + eudore.NewClientOptionHeader(eudore.HeaderReferer, "http://www.eudore.cn/"), + ) + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderReferer: {"http://www.eudore.cn/"}}) + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderReferer: {"http://www.example.com"}}) + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderReferer: {"http://www.example.com/"}}) + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderReferer: {"http://www.example.com/1"}}) + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderReferer: {"http://www.example.com/1/1"}}) + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderReferer: {"http://www.example.com/1/2"}}) + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderReferer: {"http://127.0.0.1:80/1"}}) + app.NewRequest(nil, "GET", "/", http.Header{eudore.HeaderReferer: {"http://127.0.0.10:80"}}) app.CancelFunc() app.Run() @@ -946,406 +649,3 @@ func TestMiddlewareRrouterRewrite(*testing.T) { app.CancelFunc() app.Run() } - -func TestMiddlewareNethttpBasicAuth(*testing.T) { - data := map[string]string{"user": "pw"} - - app := eudore.NewApp() - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {}) - app.SetHandler(middleware.NewNetHTTPBasicAuthFunc(mux, data)) - - app.NewRequest(nil, "GET", "/1") - app.NewRequest(nil, "GET", "/2", eudore.NewClientHeader("Authorization", "Basic dXNlcjpwdw==")) - - app.CancelFunc() - app.Run() -} - -func TestMiddlewareNethttpBlack(*testing.T) { - app := eudore.NewApp() - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {}) - app.SetHandler(middleware.NewNetHTTPBlackFunc(mux, map[string]bool{ - "127.0.0.1/8": true, - "192.168.0.0/16": true, - "10.0.0.0/8": false, - })) - - app.NewRequest(nil, "GET", "/eudore/debug/black/ui") - app.NewRequest(nil, "GET", "/eudore/debug/black/ui") - app.NewRequest(nil, "PUT", "/eudore/debug/black/black/10.127.87.0?mask=24") - app.NewRequest(nil, "PUT", "/eudore/debug/black/white/10.127.87.0?mask=24") - app.NewRequest(nil, "GET", "/eudore/debug/black/data") - app.NewRequest(nil, "DELETE", "/eudore/debug/black/black/10.127.87.0?mask=24") - app.NewRequest(nil, "DELETE", "/eudore/debug/black/white/10.127.87.0?mask=24") - app.NewRequest(nil, "DELETE", "/eudore/debug/black/black/10.127.87.0?mask=24") - app.NewRequest(nil, "DELETE", "/eudore/debug/black/white/10.127.87.0?mask=24") - - app.NewRequest(nil, "GET", "/eudore") - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXForwardedFor, "192.168.1.4 192.168.1.1")) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "127.0.0.1:29398")) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "192.168.75.1:8298")) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "10.1.1.1:2334")) - app.NewRequest(nil, "GET", "/eudore", eudore.NewClientHeader(eudore.HeaderXRealIP, "172.17.1.1:2334")) - - app.CancelFunc() - app.Run() -} - -func TestMiddlewareNethttpRateRequest(*testing.T) { - app := eudore.NewApp() - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {}) - app.SetHandler(middleware.NewNetHTTPRateRequestFunc(mux, 1, 3, func(req *http.Request) string { - // 自定义限流key - return req.UserAgent() - })) - - app.NewRequest(nil, "GET", "/") - app.NewRequest(nil, "GET", "/") - app.NewRequest(nil, "GET", "/") - app.NewRequest(nil, "GET", "/") - app.NewRequest(nil, "GET", "/") - - app.CancelFunc() - app.Run() -} - -func TestMiddlewareNethttpRewrite(*testing.T) { - rewritedata := map[string]string{ - "/js/*": "/public/js/$0", - "/api/v1/users/*/orders/*": "/api/v3/user/$0/order/$1", - "/d/*": "/d/$0-$0", - "/api/v1/*": "/api/v3/$0", - "/api/v2/*": "/api/v3/$0", - "/help/history*": "/api/v3/history", - "/help/history": "/api/v3/history", - "/help/*": "$0", - } - - app := eudore.NewApp() - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {}) - app.SetHandler(middleware.NewNetHTTPRewriteFunc(mux, rewritedata)) - - app.NewRequest(nil, "GET", "/") - app.NewRequest(nil, "GET", "/js/") - app.NewRequest(nil, "GET", "/js/index.js") - app.NewRequest(nil, "GET", "/api/v1/user") - app.NewRequest(nil, "GET", "/api/v1/user/new") - app.NewRequest(nil, "GET", "/api/v1/users/v3/orders/8920") - app.NewRequest(nil, "GET", "/api/v1/users/orders") - app.NewRequest(nil, "GET", "/api/v2") - app.NewRequest(nil, "GET", "/api/v2/user") - app.NewRequest(nil, "GET", "/d/3") - app.NewRequest(nil, "GET", "/help/history") - app.NewRequest(nil, "GET", "/help/historyv2") - - app.CancelFunc() - app.Run() -} - -/* -goos: linux -goarch: amd64 -BenchmarkMiddlewareBlackTree-2 1000000 1212 ns/op 0 B/op 0 allocs/op -BenchmarkMiddlewareBlackArray-2 1000000 1956 ns/op 0 B/op 0 allocs/op -BenchmarkMiddlewareBlackIp2intbit-2 1000000 1654 ns/op 320 B/op 5 allocs/op -BenchmarkMiddlewareBlackNetParse-2 1000000 1989 ns/op 360 B/op 20 allocs/op -PASS -ok command-line-arguments 6.919s -*/ - -var ips []string = []string{ - "10.0.0.0/4", "127.0.0.1/8", "192.168.1.0/24", "192.168.75.0/24", "192.168.100.0/24", -} - -var requests []uint64 = []uint64{ - 725415979, 2727437335, 889276411, 4005535794, 3864288534, 3906172701, 282878927, 1284469666, 730935782, 3371086418, - 1506312450, 1351422527, 1427742110, 1787801507, 2252116061, 229145224, 2463885032, 977944943, 3785363053, 3752670878, - 1109101831, 523139815, 2692892509, 822628332, 1521829731, 1137604504, 3946127316, 3492727158, 3701842868, 1345785201, - 2479587981, 1525387624, 2335875430, 2742578379, 842531784, 4164034788, 4067025409, 3579565778, 1135250289, 2272239320, - 2221887036, 47163049, 756685807, 3064055796, 2298095091, 3099116819, 4070972416, 1014033, 3023215026, 555430525, - 3702021454, 2340802113, 2507760403, 510831888, 3073321492, 4221140315, 1198583294, 1495418697, 827583711, 813333453, - 2746343126, 3755199452, 1697814659, 365059279, 3478405321, 2147566177, 281339662, 2742376600, 2293307920, 2061663865, - 913999062, 542572186, 4225265321, 633066366, 2063795404, 522841846, 195572401, 124532676, 2456662794, 3902204181, - 2491401143, 4233234751, 69766498, 388520887, 1017105985, 62871287, 3328355052, 1705168586, 2260082173, 3340006743, - 2211140888, 1906467873, 1247205260, 1492905294, 1014862918, 2587182986, 1040587870, 3570772999, 3084952258, 2425691705, -} - -var requeststrs []string = []string{ - "43.60.248.43", "162.145.100.23", "53.1.71.251", "238.191.160.50", "230.84.93.22", "232.211.119.29", "16.220.99.207", "76.143.115.162", "43.145.49.230", "200.238.178.82", - "89.200.129.2", "80.141.18.63", "85.25.157.158", "106.143.175.163", "134.60.144.93", "13.168.122.136", "146.219.230.232", "58.74.65.111", "225.160.14.109", "223.173.54.158", - "66.27.141.7", "31.46.122.231", "160.130.71.93", "49.8.79.236", "90.181.71.99", "67.206.119.152", "235.53.31.212", "208.46.201.118", "220.165.163.180", "80.55.13.113", - "147.203.130.141", "90.235.145.104", "139.58.161.102", "163.120.108.203", "50.56.3.200", "248.50.32.228", "242.105.226.1", "213.91.214.210", "67.170.139.113", "135.111.158.216", - "132.111.78.60", "2.207.166.169", "45.26.27.239", "182.161.199.244", "136.250.37.243", "184.184.197.19", "242.166.28.0", "0.15.121.17", "180.50.153.178", "33.27.50.125", - "220.168.93.78", "139.133.206.65", "149.121.99.19", "30.114.173.16", "183.47.42.20", "251.153.125.91", "71.112.237.254", "89.34.71.73", "49.83.236.223", "48.122.123.205", - "163.177.222.214", "223.211.203.220", "101.50.152.131", "21.194.92.207", "207.84.64.201", "128.1.66.97", "16.196.231.14", "163.117.88.152", "136.177.26.16", "122.226.126.121", - "54.122.132.214", "32.86.254.154", "251.216.110.169", "37.187.211.126", "123.3.4.204", "31.41.238.246", "11.168.50.177", "7.108.55.196", "146.109.179.10", "232.150.233.21", - "148.127.195.183", "252.82.9.63", "4.40.141.98", "23.40.91.183", "60.159.206.65", "3.191.86.247", "198.98.170.236", "101.162.206.202", "134.182.29.253", "199.20.117.87", - "131.203.85.24", "113.162.100.33", "74.86.215.140", "88.251.237.78", "60.125.148.70", "154.53.71.138", "62.6.28.94", "212.213.172.7", "183.224.162.194", "144.149.30.57", -} - -/* -func TestMiddlewareBlackResult(t *testing.T) { - tree := new(middleware.BlackNode) - array := new(BlackNodeArray) - for _, ip := range ips { - tree.Insert(ip) - array.Insert(ip) - } - for _, ip := range requests { - if tree.Look(ip) != array.Look(ip) { - t.Logf("tree: %t array: %t result not equal %d %s", tree.Look(ip), array.Look(ip), ip, int2ip(ip)) - } - } -} - -func BenchmarkMiddlewareBlackTree(b *testing.B) { - node := new(middleware.BlackNode) - for _, ip := range ips { - node.Insert(ip) - } - b.ReportAllocs() - for i := 0; i < b.N; i++ { - for _, ip := range requests { - node.Look(ip) - } - } -} -*/ - -func BenchmarkMiddlewareBlackArray(b *testing.B) { - node := new(BlackNodeArray) - b.ReportAllocs() - for _, ip := range ips { - node.Insert(ip) - } - for i := 0; i < b.N; i++ { - for _, ip := range requests { - node.Look(ip) - } - } -} - -func TestMiddlewareBlackParseip(t *testing.T) { - for _, ip := range ips { - ip1, bit1 := ip2intbit(ip) - ip2, bit2 := ip2netintbit(ip) - if ip1 != ip2 || bit1 != bit2 { - t.Log("ip parse error", ip, ip1, ip2, bit1, bit2) - } - } -} - -func BenchmarkMiddlewareBlackIp2intbit(b *testing.B) { - b.ReportAllocs() - for i := 0; i < b.N; i++ { - for _, ip := range ips { - ip2intbit(ip) - } - } -} - -func BenchmarkMiddlewareBlackNetParse(b *testing.B) { - b.ReportAllocs() - for i := 0; i < b.N; i++ { - for _, ip := range ips { - ip2netintbit(ip) - } - } -} - -// BlackNodeArray 定义数组遍历实现ip解析 -type BlackNodeArray struct { - Data []uint64 - Mask []uint - Count []uint64 -} - -// Insert 方法给黑名单节点新增一个ip或ip段。 -func (node *BlackNodeArray) Insert(ip string) { - iip, bit := ip2intbit(ip) - node.Data = append(node.Data, iip>>(32-bit)) - node.Mask = append(node.Mask, 32-bit) - node.Count = append(node.Count, 0) -} - -// Look 方法匹配ip是否在黑名单节点,命中则节点计数加一。 -func (node *BlackNodeArray) Look(ip uint64) bool { - for i := range node.Data { - if node.Data[i] == (ip >> node.Mask[i]) { - node.Count[i]++ - return true - } - } - return false -} - -// BlackNodeArrayNet 定义基于net库实现ip遍历匹配,支持ipv6. -type BlackNodeArrayNet struct { - Data []net.IP - Mask []net.IPMask - Count []uint64 -} - -// Insert 方法给黑名单节点新增一个ip或ip段。 -func (node *BlackNodeArrayNet) Insert(ip string) { - _, ipnet, _ := net.ParseCIDR(ip) - node.Data = append(node.Data, ipnet.IP) - node.Mask = append(node.Mask, ipnet.Mask) - node.Count = append(node.Count, 0) -} - -// Look 方法匹配ip是否在黑名单节点,命中则节点计数加一。 -func (node *BlackNodeArrayNet) Look(ip string) bool { - netip := net.ParseIP(ip) - for i := range node.Data { - if node.Data[i].Equal(netip.Mask(node.Mask[i])) { - node.Count[i]++ - return true - } - } - return false -} - -func ip2netintbit(ip string) (uint64, uint) { - ipaddr, ipnet, _ := net.ParseCIDR(ip) - length := len(ipaddr) - bit, _ := ipnet.Mask.Size() - var sum uint64 - sum += uint64(ipaddr[length-4]) << 24 - sum += uint64(ipaddr[length-3]) << 16 - sum += uint64(ipaddr[length-2]) << 8 - sum += uint64(ipaddr[length-1]) - return sum, uint(bit) -} - -func ip2intbit(ip string) (uint64, uint) { - bit := 32 - pos := strings.Index(ip, "/") - if pos != -1 { - bit, _ = strconv.Atoi(ip[pos+1:]) - ip = ip[:pos] - } - return ip2int(ip), uint(bit) -} - -func ip2int(ip string) uint64 { - bits := strings.Split(ip, ".") - b0, _ := strconv.Atoi(bits[0]) - b1, _ := strconv.Atoi(bits[1]) - b2, _ := strconv.Atoi(bits[2]) - b3, _ := strconv.Atoi(bits[3]) - - var sum uint64 - sum += uint64(b0) << 24 - sum += uint64(b1) << 16 - sum += uint64(b2) << 8 - sum += uint64(b3) - return sum -} - -func int2ip(ip uint64) string { - var bytes [4]uint64 - bytes[0] = ip & 0xFF - bytes[1] = (ip >> 8) & 0xFF - bytes[2] = (ip >> 16) & 0xFF - bytes[3] = (ip >> 24) & 0xFF - return fmt.Sprintf("%d.%d.%d.%d", bytes[3], bytes[2], bytes[1], bytes[0]) -} - -func BenchmarkMiddlewareRewrite(b *testing.B) { - rewritedata := map[string]string{ - "/js/*": "/public/js/$0", - "/api/v1/users/*/orders/*": "/api/v3/user/$0/order/$1", - "/d/*": "/d/$0-$0", - "/api/v1/*": "/api/v3/$0", - "/api/v2/*": "/api/v3/$0", - "/help/history*": "/api/v3/history", - "/help/history": "/api/v3/history", - "/help/*": "$0", - } - - app := eudore.NewApp() - app.SetValue(eudore.ContextKeyLogger, eudore.NewLoggerInit()) - app.AddMiddleware("global", middleware.NewRewriteFunc(rewritedata)) - app.AnyFunc("/*", eudore.HandlerEmpty) - paths := []string{"/", "/js/", "/js/index.js", "/api/v1/user", "/api/v1/user/new", "/api/v1/users/v3/orders/8920", "/api/v1/users/orders", "/api/v2", "/api/v2/user", "/d/3", "/help/history", "/help/historyv2"} - w, r := httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil) - b.ReportAllocs() - for i := 0; i < b.N; i++ { - for _, path := range paths { - r.URL.Path = path - app.ServeHTTP(w, r) - } - } -} -func BenchmarkMiddlewareRewriteWithZero(b *testing.B) { - app := eudore.NewApp() - app.SetValue(eudore.ContextKeyLogger, eudore.NewLoggerInit()) - app.AnyFunc("/*", eudore.HandlerEmpty) - paths := []string{"/", "/js/", "/js/index.js", "/api/v1/user", "/api/v1/user/new", "/api/v1/users/v3/orders/8920", "/api/v1/users/orders", "/api/v2", "/api/v2/user", "/d/3", "/help/history", "/help/historyv2"} - w, r := httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil) - b.ReportAllocs() - for i := 0; i < b.N; i++ { - for _, path := range paths { - r.URL.Path = path - app.ServeHTTP(w, r) - } - } -} - -func BenchmarkMiddlewareRewriteWithRouter(b *testing.B) { - routerdata := map[string]interface{}{ - "/js/*0": newRewriteFunc("/public/js/$0"), - "/api/v1/users/:0/orders/*1": newRewriteFunc("/api/v3/user/$0/order/$1"), - "/d/*0": newRewriteFunc("/d/$0-$0"), - "/api/v1/*0": newRewriteFunc("/api/v3/$0"), - "/api/v2/*0": newRewriteFunc("/api/v3/$0"), - "/help/history*0": newRewriteFunc("/api/v3/history"), - "/help/history": newRewriteFunc("/api/v3/history"), - "/help/*0": newRewriteFunc("$0"), - } - app := eudore.NewApp() - app.SetValue(eudore.ContextKeyLogger, eudore.NewLoggerInit()) - app.AddMiddleware("global", middleware.NewRouterFunc(routerdata)) - app.AnyFunc("/*", eudore.HandlerEmpty) - paths := []string{"/", "/js/", "/js/index.js", "/api/v1/user", "/api/v1/user/new", "/api/v1/users/v3/orders/8920", "/api/v1/users/orders", "/api/v2", "/api/v2/user", "/d/3", "/help/history", "/help/historyv2"} - w, r := httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil) - b.ReportAllocs() - for i := 0; i < b.N; i++ { - for _, path := range paths { - r.URL.Path = path - app.ServeHTTP(w, r) - } - } -} - -func newRewriteFunc(path string) eudore.HandlerFunc { - paths := strings.Split(path, "$") - Index := make([]string, 1, len(paths)*2-1) - Data := make([]string, 1, len(paths)*2-1) - Index[0] = "" - Data[0] = paths[0] - for _, path := range paths[1:] { - Index = append(Index, path[0:1]) - Data = append(Data, "") - if path[1:] != "" { - Index = append(Index, "") - Data = append(Data, path[1:]) - } - } - return func(ctx eudore.Context) { - buffer := bytes.NewBuffer(nil) - for i := range Index { - if Index[i] == "" { - buffer.WriteString(Data[i]) - } else { - buffer.WriteString(ctx.GetParam(Index[i])) - } - } - ctx.Request().URL.Path = buffer.String() - } -} diff --git a/_example/otherNotify.go b/_example/otherNotify.go new file mode 100644 index 0000000..8fef40d --- /dev/null +++ b/_example/otherNotify.go @@ -0,0 +1,219 @@ +package main + +/* +go build -o ~/go/bin/gonotify otherNotify.go + +先app.Config设置notify配置,然后启动notify。 +如果是notify的程序可以通过环境变量eudore.EnvEudoreIsNotify检测。 +当程序启动时会如果eudore.EnvEudoreIsNotify不存在,则使用notify开始监听阻塞app后续初始化,否在就忽略notify然后进行正常app启动。 + +实现原理基于fsnotify检测目录内go文件变化,然后执行编译命令,如果编译成功就kill原进程并执行启动命令。 + +其他类似工具:air +*/ + +import ( + "context" + "fmt" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "sync" + "time" + + "github.com/eudore/eudore" + "github.com/fsnotify/fsnotify" +) + +var startcmd string + +func init() { + if runtime.GOOS == "windows" { + startcmd = "powershell" + } else { + startcmd = "bash" + } +} + +// Notify 定义监听重启对象。 +type App struct { + sync.Mutex + *NotifyConfig + *eudore.App + Watcher *fsnotify.Watcher + lastBuild context.CancelFunc + lastProcess context.CancelFunc +} + +type NotifyConfig struct { + Workdir string `json:"workdir" alias:"workdir"` + Command string `json:"command" alias:"command"` + Pidfile string `json:"pidfile" alias:"pidfile"` + Build string `json:"build" alias:"build"` + Start string `json:"start" alias:"start"` + Watch string `json:"watch" alias:"watch"` +} + +func main() { + app := NewApp() + app.Parse() + go app.Run() + app.App.Run() +} + +// NewApp 函数创建一个Notify对象。 +func NewApp() *App { + conf := &NotifyConfig{ + Build: "", + Start: "go run .", + Watch: ".", + } + app := &App{ + NotifyConfig: conf, + App: eudore.NewApp(), + } + app.SetValue(eudore.ContextKeyConfig, eudore.NewConfig(conf)) + app.ParseOption( + eudore.DefaultConfigAllParseFunc, + app.NewParseWatcherFunc(), + ) + return app +} + +func (app *App) NewParseWatcherFunc() eudore.ConfigParseFunc { + return func(context.Context, eudore.Config) error { + watcher, err := fsnotify.NewWatcher() + if err != nil { + return err + } + app.Watcher = watcher + return nil + } +} + +// Run 方法启动Notify。 +// +// 调用App.Logger +func (n *App) Run() { + n.App.Info("notify buildCmd", n.Build) + n.App.Info("notify startCmd", n.Start) + for _, path := range strings.Split(n.Watch, ";") { + n.WatchAll(strings.TrimSpace(path)) + } + + n.buildAndRestart() + + var timer = time.AfterFunc(1000*time.Hour, n.buildAndRestart) + defer func() { + timer.Stop() + if n.lastBuild != nil { + n.lastBuild() + } + if n.lastProcess != nil { + n.lastProcess() + } + }() + + for { + select { + case event, ok := <-n.Watcher.Events: + if !ok { + return + } + + // 监听go文件写入 + if event.Name[len(event.Name)-3:] == ".go" && event.Op&fsnotify.Write == fsnotify.Write { + n.App.Debug("modified file:", event.Name) + + // 等待0.1秒执行更新,防止短时间大量触发 + timer.Reset(100 * time.Millisecond) + } + case err, ok := <-n.Watcher.Errors: + if !ok { + return + } + n.App.Error("notify watcher error:", err) + case <-n.App.Done(): + return + } + } +} + +func (n *App) buildAndRestart() { + if n.Build != "" { + // 取消上传编译 + n.Lock() + if n.lastBuild != nil { + n.lastBuild() + } + ctx, cannel := context.WithCancel(n.App.Context) + n.lastBuild = cannel + n.Unlock() + // 执行编译命令 + cmd := exec.CommandContext(ctx, startcmd, "-c", n.Build) + cmd.Env = os.Environ() + body, err := cmd.CombinedOutput() + if err != nil { + fmt.Printf("notify build error: \n%s", body) + n.App.Errorf("notify build error: %s", body) + return + } + } + n.App.Info("notify build success, restart process...") + time.Sleep(10 * time.Millisecond) + // 重启子进程 + n.restart() +} + +func (n *App) restart() { + // 关闭旧进程 + n.Lock() + if n.lastProcess != nil { + n.lastProcess() + } + ctx, cannel := context.WithCancel(n.App.Context) + n.lastProcess = cannel + n.Unlock() + // 启动新进程 + cmd := exec.CommandContext(ctx, startcmd, "-c", n.Start) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Env = os.Environ() + err := cmd.Start() + if err != nil { + n.App.Error("notify start error:", err) + } +} + +// WatchAll 方法添加一个文件或目录,如果/结尾的目录会递归监听子目录。 +func (n *App) WatchAll(path string) { + if path != "" { + // 递归目录处理 + listDir(path, n.watch) + n.watch(path) + } +} + +func (n *App) watch(path string) { + n.App.Debug("notify add watch dir " + path) + err := n.Watcher.Add(path) + if err != nil { + n.App.Error(err) + } +} + +func listDir(path string, fn func(string)) { + files, _ := ioutil.ReadDir(path) + for _, f := range files { + // 忽略隐藏目录,例如: .git + if f.IsDir() && f.Name()[0] != '.' { + path := filepath.Join(path, f.Name()) + fn(path) + listDir(path, fn) + } + } +} diff --git a/_example/policy_test.go b/_example/policy_test.go index 8853fb6..2026347 100644 --- a/_example/policy_test.go +++ b/_example/policy_test.go @@ -3,6 +3,7 @@ package eudore_test import ( "fmt" "net/http" + "net/url" "testing" "time" @@ -56,22 +57,22 @@ func TestPolicyPbacParse(t *testing.T) { now := time.Now().Add(time.Hour).Unix() app.NewRequest(nil, "GET", "/") - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderAuthorization, "000")) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderAuthorization, "Bearer 000")) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyaWQiOjEwLCJwb2xpY3kiOiJiYXNlNjQiLCJleHBpcmF0aW9uIjoxNjQ5MTQwMzkwfQ.2mqeTZZizrP")) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderAuthorization, `Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.{"userid":10,"expiration":1649140575}.ffikvNJyZVA8u01PtZ_3fUwQJQ5aGjw_0uCKhoKDr9w`)) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderAuthorization, `Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyaWQiOiIxMCIsImV4cGlyYXRpb24iOjE2NDkxNDA1NzV9.LgfnJJ-UknB1hOJIA1FrYbpeCNJ2cRuSj_r_bJo8vA8`)) - - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderAuthorization, pbac.NewBearer(10, "", time.Now().Add(time.Hour*-1).Unix()))) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderAuthorization, pbac.NewBearer(10, "", now))) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderAuthorization, "Bearer "+pbac.Signaturer.Signed(&policy.SignatureUser{UserID: 10, Policy: "base64", Expiration: now}))) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderAuthorization, pbac.NewBearer(10, "base64", now))) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderAuthorization, pbac.NewBearer(10, `[{"effect":true,"action":["Home"]}]`, now))) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderAuthorization, pbac.NewBearer(10, `[{"effect":false,"action":["Home"]}]`, now))) - app.NewRequest(nil, "GET", "/", eudore.NewClientHeader(eudore.HeaderAuthorization, pbac.NewBearer(10, `[{"effect":true,"action":["Index"]}]`, now))) + app.NewRequest(nil, "GET", "/", eudore.NewClientOptionHeader(eudore.HeaderAuthorization, "000")) + app.NewRequest(nil, "GET", "/", eudore.NewClientOptionHeader(eudore.HeaderAuthorization, "Bearer 000")) + app.NewRequest(nil, "GET", "/", eudore.NewClientOptionHeader(eudore.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyaWQiOjEwLCJwb2xpY3kiOiJiYXNlNjQiLCJleHBpcmF0aW9uIjoxNjQ5MTQwMzkwfQ.2mqeTZZizrP")) + app.NewRequest(nil, "GET", "/", eudore.NewClientOptionHeader(eudore.HeaderAuthorization, `Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.{"userid":10,"expiration":1649140575}.ffikvNJyZVA8u01PtZ_3fUwQJQ5aGjw_0uCKhoKDr9w`)) + app.NewRequest(nil, "GET", "/", eudore.NewClientOptionHeader(eudore.HeaderAuthorization, `Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyaWQiOiIxMCIsImV4cGlyYXRpb24iOjE2NDkxNDA1NzV9.LgfnJJ-UknB1hOJIA1FrYbpeCNJ2cRuSj_r_bJo8vA8`)) + + app.NewRequest(nil, "GET", "/", eudore.NewClientOptionHeader(eudore.HeaderAuthorization, pbac.NewBearer(10, "", "", time.Now().Add(time.Hour*-1).Unix()))) + app.NewRequest(nil, "GET", "/", eudore.NewClientOptionHeader(eudore.HeaderAuthorization, pbac.NewBearer(10, "", "", now))) + app.NewRequest(nil, "GET", "/", eudore.NewClientOptionHeader(eudore.HeaderAuthorization, "Bearer "+pbac.Signaturer.Signed(&policy.SignatureUser{UserID: 10, Policy: "base64", Expiration: now}))) + app.NewRequest(nil, "GET", "/", eudore.NewClientOptionHeader(eudore.HeaderAuthorization, pbac.NewBearer(10, "", "base64", now))) + app.NewRequest(nil, "GET", "/", eudore.NewClientOptionHeader(eudore.HeaderAuthorization, pbac.NewBearer(10, "", `[{"effect":true,"action":["Home"]}]`, now))) + app.NewRequest(nil, "GET", "/", eudore.NewClientOptionHeader(eudore.HeaderAuthorization, pbac.NewBearer(10, "", `[{"effect":false,"action":["Home"]}]`, now))) + app.NewRequest(nil, "GET", "/", eudore.NewClientOptionHeader(eudore.HeaderAuthorization, pbac.NewBearer(10, "", `[{"effect":true,"action":["Index"]}]`, now))) option := func(req *http.Request) { - req.Header.Set(eudore.HeaderAuthorization, pbac.NewBearer(10, "", now)) + req.Header.Set(eudore.HeaderAuthorization, pbac.NewBearer(10, "", "", now)) } app.NewRequest(nil, "GET", "/", option) app.NewRequest(nil, "GET", "/static/1.js", option) @@ -87,7 +88,7 @@ func TestPolicyPbacHandler(t *testing.T) { app.AddMiddleware(middleware.NewLoggerFunc(app, "route", "action", "resource", "Userid")) app.AddMiddleware(pbac) - for i := 1; i < 11; i++ { + for i := 1; i < 13; i++ { app.AnyFunc(fmt.Sprintf("/%d action=Num:%d", i, i), eudore.HandlerEmpty) pbac.AddMember(&policy.Member{UserID: 10, PolicyID: i}) } @@ -122,7 +123,7 @@ func TestPolicyPbacHandler(t *testing.T) { Statement []Statement } var data Data - data.UserID = eudore.GetStringInt(ctx.GetParam(eudore.ParamUserid)) + data.UserID = eudore.GetAnyByString[int](ctx.GetParam(eudore.ParamUserid)) err := ctx.Bind(&data) if err != nil { return nil, err @@ -141,8 +142,8 @@ func TestPolicyPbacHandler(t *testing.T) { } return resps, nil }) - pbac.AddMember(&policy.Member{UserID: 10, PolicyID: 11, Expiration: time.Now().Add(time.Hour)}) - pbac.AddMember(&policy.Member{UserID: 10, PolicyID: 11, Expiration: time.Now().Add(time.Hour * -1)}) + pbac.AddMember(&policy.Member{UserID: 10, PolicyID: 13, Expiration: time.Now().Add(time.Hour)}) + pbac.AddMember(&policy.Member{UserID: 10, PolicyID: 13, Expiration: time.Now().Add(time.Hour * -1)}) pbac.AddPolicyString(`{"policy_id":1,"statement":[{"effect":true,"action":["Num:1"],"conditions":{"and":{"method":["GET"],"sourceip":["127.0.0.1"]}}}]}`) pbac.AddPolicyString(`{"policy_id":2,"statement":[{"effect":true,"action":["Num:2"],"conditions":{"or":{"method":["GET"],"sourceip":["127.0.0.1"]}}}]}`) @@ -150,26 +151,28 @@ func TestPolicyPbacHandler(t *testing.T) { pbac.AddPolicyString(`{"policy_id":4,"statement":[{"effect":true,"action":["Num:4"],"conditions":{"date":{"before":"2030-12-31"}}}]}`) pbac.AddPolicyString(`{"policy_id":5,"statement":[{"effect":true,"action":["Num:5"],"conditions":{"time":{"before":"23:59:59"}}}]}`) pbac.AddPolicyString(`{"policy_id":6,"statement":[{"effect":true,"action":["Num:6"],"conditions":{"method":["GET"]}}]}`) - pbac.AddPolicyString(`{"policy_id":7,"statement":[{"effect":true,"action":["Num:7"],"conditions":{"params":{"action":["Num:7"]}}}]}`) - pbac.AddPolicyString(`{"policy_id":8,"statement":[{"effect":false,"action":["Num:8"]}]}`) - pbac.AddPolicyString(`{"policy_id":9,"statement":[{"effect":true,"action":["Menu"],"data":{"menu":["Home"]}}]}`) - pbac.AddPolicyString(`{"policy_id":10,"statement":[{"effect":true,"action":["Menu"],"data":{"menu":["Index"]}}]}`) - pbac.AddPolicyString(`{"policy_id":12}`) - pbac.AddPolicyString(`{"policy_id":13,}`) - - app.Client = app.WithClient(eudore.NewClientHeader(eudore.HeaderAuthorization, pbac.NewBearer(10, "", time.Now().Add(time.Hour).Unix()))) + pbac.AddPolicyString(`{"policy_id":7,"statement":[{"effect":true,"action":["Num:7"],"conditions":{"path":["/"]}}]}`) + pbac.AddPolicyString(`{"policy_id":8,"statement":[{"effect":true,"action":["Num:8"],"conditions":{"path":["/8"]}}]}`) + pbac.AddPolicyString(`{"policy_id":9,"statement":[{"effect":true,"action":["Num:9"],"conditions":{"params":{"action":["Num:9"]}}}]}`) + pbac.AddPolicyString(`{"policy_id":10,"statement":[{"effect":false,"action":["Num:10"]}]}`) + pbac.AddPolicyString(`{"policy_id":11,"statement":[{"effect":true,"action":["Menu"],"data":{"menu":["Home"]}}]}`) + pbac.AddPolicyString(`{"policy_id":12,"statement":[{"effect":true,"action":["Menu"],"data":{"menu":["Index"]}}]}`) + pbac.AddPolicyString(`{"policy_id":13}`) + pbac.AddPolicyString(`{"policy_id":14,}`) + + app.Client = app.WithClient(eudore.NewClientOptionHeader(eudore.HeaderAuthorization, pbac.NewBearer(10, "", "", time.Now().Add(time.Hour).Unix()))) app.NewRequest(nil, "GET", "/1") app.NewRequest(nil, "PUT", "/1") app.NewRequest(nil, "GET", "/2") app.NewRequest(nil, "PUT", "/2") - app.NewRequest(nil, "GET", "/3", eudore.NewClientHeader(eudore.HeaderXRealIP, "127.0.0.1")) - app.NewRequest(nil, "GET", "/3", eudore.NewClientHeader(eudore.HeaderXRealIP, "172.17.1.3")) + app.NewRequest(nil, "GET", "/3", eudore.NewClientOptionHeader(eudore.HeaderXRealIP, "127.0.0.1")) + app.NewRequest(nil, "GET", "/3", eudore.NewClientOptionHeader(eudore.HeaderXRealIP, "172.17.1.3")) app.NewRequest(nil, "GET", "/4") app.NewRequest(nil, "GET", "/5") app.NewRequest(nil, "PUT", "/6") - app.NewRequest(nil, "GET", "/6") app.NewRequest(nil, "GET", "/7") app.NewRequest(nil, "GET", "/8") + app.NewRequest(nil, "GET", "/9") app.NewRequest(nil, "GET", "/menu") app.NewRequest(nil, "PUT", "/has", eudore.NewClientBodyJSON(map[string]interface{}{"action": "Menu", "resource": "/has"})) app.NewRequest(nil, "GET", "/runtime") @@ -187,6 +190,8 @@ func (*User023Controller) Get(eudore.Context) {} func (*User023Controller) GetIcon(eudore.Context) {} func TestPolicyutil(t *testing.T) { + policy.NewSignaturerJwt(nil).Signed(make(chan int)) + pbac := policy.NewPolicys() pbac.ActionFunc = func(ctx eudore.Context) string { return ctx.GetQuery("action") } pbac.AddPolicyString(`{ @@ -218,14 +223,14 @@ func TestPolicyutil(t *testing.T) { app.AddController(&User023Controller{}) app.Info((User023Controller{}).ControllerParam("github.com/eudore/eudore", "User023Controller", "Get")) - app.NewRequest(nil, "GET", "/", eudore.NewClientQuery("action", "eudore:user:Get1")) - app.NewRequest(nil, "GET", "/", eudore.NewClientQuery("action", "eudore:user:Get2")) - app.NewRequest(nil, "GET", "/", eudore.NewClientQuery("action", "eudore:user:Get3")) - app.NewRequest(nil, "GET", "/", eudore.NewClientQuery("action", "eudore:group:22")) - app.NewRequest(nil, "GET", "/", eudore.NewClientQuery("action", "eudore:group:")) - app.NewRequest(nil, "GET", "/", eudore.NewClientQuery("action", "eudore")) - app.NewRequest(nil, "GET", "/", eudore.NewClientQuery("action", "eudore:ns:Get")) - app.NewRequest(nil, "GET", "/", eudore.NewClientQuery("action", "eudore:ns:Get2")) + app.NewRequest(nil, "GET", "/", url.Values{"action": {"eudore:user:Get1"}}) + app.NewRequest(nil, "GET", "/", url.Values{"action": {"eudore:user:Get2"}}) + app.NewRequest(nil, "GET", "/", url.Values{"action": {"eudore:user:Get3"}}) + app.NewRequest(nil, "GET", "/", url.Values{"action": {"eudore:group:22"}}) + app.NewRequest(nil, "GET", "/", url.Values{"action": {"eudore:group:"}}) + app.NewRequest(nil, "GET", "/", url.Values{"action": {"eudore"}}) + app.NewRequest(nil, "GET", "/", url.Values{"action": {"eudore:ns:Get"}}) + app.NewRequest(nil, "GET", "/", url.Values{"action": {"eudore:ns:Get2"}}) app.CancelFunc() app.Run() diff --git a/_example/protobuf_test.go b/_example/protobuf_test.go index da80867..09aa8ae 100644 --- a/_example/protobuf_test.go +++ b/_example/protobuf_test.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "fmt" "io" + "net/http" "testing" "time" @@ -220,10 +221,10 @@ func TestProtoBufHandlerData(t *testing.T) { }) app.NewRequest(nil, "PUT", "/", - eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeApplicationProtobuf), - eudore.NewClientHeader(eudore.HeaderContentType, eudore.MimeApplicationProtobuf), - eudore.NewClientBody(nil), - eudore.NewClientDumpHead(), + http.Header{ + eudore.HeaderAccept: []string{eudore.MimeApplicationProtobuf}, + eudore.HeaderContentType: []string{eudore.MimeApplicationProtobuf}, + }, ) app.CancelFunc() diff --git a/_example/routerDelete.go b/_example/routerDelete.go index 8948bd2..8ef815d 100644 --- a/_example/routerDelete.go +++ b/_example/routerDelete.go @@ -8,15 +8,13 @@ package main import ( "github.com/eudore/eudore" - "github.com/eudore/eudore/component/httptest" "github.com/eudore/eudore/middleware" ) func main() { app := eudore.NewApp() - app.SetValue(eudore.ContextKeyRouter, eudore.NewRouterStd(eudore.NewRouterCoreLock(nil))) - - client := httptest.NewClient(app) + app.SetValue(eudore.ContextKeyRouter, eudore.NewRouter(eudore.NewRouterCoreLock(nil))) + client := app.WithClient(eudore.NewClientCheckStatus(200)) register := app.Group(" register=off") app.AnyFunc("/version", echoStringHandler("any version")) @@ -24,11 +22,11 @@ func main() { app.AnyFunc("/version", echoStringHandler("any version")) app.AnyFunc("/version1", echoStringHandler("any version")) - client.NewRequest("GET", "/version").Do().CheckStatus(200).CheckBodyString("any version") + client.NewRequest(nil, "GET", "/version", eudore.NewClientCheckBody("any version")) app.GetFunc("/version", echoStringHandler("get version")) - client.NewRequest("GET", "/version").Do().CheckStatus(200).CheckBodyString("get version") + client.NewRequest(nil, "GET", "/version", eudore.NewClientCheckBody("get version")) register.AddHandler("GET,POST", "/version", echoStringHandler("get version")) - client.NewRequest("GET", "/version").Do().CheckStatus(200).CheckBodyString("any version") + client.NewRequest(nil, "GET", "/version", eudore.NewClientCheckBody("any version")) register.AnyFunc("/version*", echoStringHandler("any version")) register.AnyFunc("/version0", echoStringHandler("any version")) register.AnyFunc("/version2", echoStringHandler("any version")) @@ -43,17 +41,17 @@ func main() { // ---------------- 测试 ---------------- - app.SetValue(eudore.ContextKeyRouter, eudore.NewRouterStd(eudore.NewRouterCoreLock(nil))) + app.SetValue(eudore.ContextKeyRouter, eudore.NewRouter(eudore.NewRouterCoreLock(nil))) register = app.Group(" register=off") app.AnyFunc("/eudore/debug/look/*", middleware.NewLookFunc(app)) app.AnyFunc("/version", echoStringHandler("any version")) app.AnyFunc("/version1", echoStringHandler("any version")) app.AnyFunc("/version2", echoStringHandler("any version")) - client.NewRequest("GET", "/version").Do().CheckStatus(200).CheckBodyString("any version") + client.NewRequest(nil, "GET", "/version", eudore.NewClientCheckBody("any version")) app.GetFunc("/version", echoStringHandler("get version")) - client.NewRequest("GET", "/version").Do().CheckStatus(200).CheckBodyString("get version") + client.NewRequest(nil, "GET", "/version", eudore.NewClientCheckBody("get version")) register.AddHandler("GET,POST", "/version", echoStringHandler("get version")) - client.NewRequest("GET", "/version").Do().CheckStatus(200).CheckBodyString("any version") + client.NewRequest(nil, "GET", "/version", eudore.NewClientCheckBody("any version")) register.GetFunc("/version", echoStringHandler("get version")) register.AnyFunc("/version*", echoStringHandler("any version")) register.AnyFunc("/version0", echoStringHandler("any version")) @@ -71,10 +69,10 @@ func main() { app.AnyFunc("/api/v1/user/id/:id", eudore.HandlerEmpty) app.AnyFunc("/api/v1/user/name/*name", eudore.HandlerEmpty) - app.AnyFunc("/api/v1/user/:id|isnum", eudore.HandlerEmpty) + app.AnyFunc("/api/v1/user/:id|num", eudore.HandlerEmpty) app.AnyFunc("/api/v1/user/*name|nozero", eudore.HandlerEmpty) - register.AnyFunc("/api/v1/user/:id|isnum/", eudore.HandlerEmpty) - register.AnyFunc("/api/v1/user/:id|isnum", eudore.HandlerEmpty) + register.AnyFunc("/api/v1/user/:id|num/", eudore.HandlerEmpty) + register.AnyFunc("/api/v1/user/:id|num", eudore.HandlerEmpty) register.AnyFunc("/api/v1/user/*name|nozero", eudore.HandlerEmpty) register.AnyFunc("/api/v1/user/id/:id", eudore.HandlerEmpty) register.AnyFunc("/api/v1/user/name/*name", eudore.HandlerEmpty) diff --git a/_example/routerStd.go b/_example/routerStd.go index 4aba127..47ad884 100644 --- a/_example/routerStd.go +++ b/_example/routerStd.go @@ -6,7 +6,7 @@ RouterStd是eudore的默认路由器,使用基数树算法独立实现,性 具有路由匹配优先级: 常量匹配 > 变量校验匹配 >变量匹配 > 通配符校验匹配 > 通配符匹配 方法优先级: 具体方法 > Any方法 -用法:在正常变量和通配符后,使用'|'符号分割,后为校验规则,isnum是校验函数;{min:100}为动态检验函数,min是动态校验函数名称,':'后为参数;如果为'^'开头为正则校验,并且要使用'$'作为结尾。 +用法:在正常变量和通配符后,使用'|'符号分割,后为校验规则,num是校验函数;{min:100}为动态检验函数,min是动态校验函数名称,':'后为参数;如果为'^'开头为正则校验,并且要使用'$'作为结尾。 在路径中使用'{}'包裹的一段字符串为块模式,切分时将整块紧跟上一个字符串,这样允许在校验规则内使用任何字符, 字符空格、冒号、星号、前花括号、后花括号、斜杠均为特殊符号(' '、':'、'*'、'{'、'}'、'/'),一定需要使用块模式包裹字符串。 @@ -15,10 +15,10 @@ RouterStd是eudore的默认路由器,使用基数树算法独立实现,性 例如路径切割的切片,首字符为':'是变量匹配,首字符为'*'是通配符匹配,其他都是常量字符串匹配。 变量匹配从当前到下一个斜杠('/')处或结尾,通配符匹配当前位置到结尾,常量匹配对应的字符串。 ``` -:num|isnum +:num|num :num|{min:100} :num|{^0.*$} -*num|isnum +*num|num *num|{min:100} *num|{^0.*$} ``` @@ -26,18 +26,14 @@ RouterStd是eudore的默认路由器,使用基数树算法独立实现,性 import ( "github.com/eudore/eudore" - "github.com/eudore/eudore/component/httptest" "github.com/eudore/eudore/middleware" ) func main() { - // 默认路由器就是 NewRouterStd(nil) app := eudore.NewApp() - app.AnyFunc("/eudore/debug/look/*", middleware.NewLookFunc(app.Router)) + app.AddMiddleware(middleware.NewLoggerFunc(app, "route")) + app.AnyFunc("/eudore/debug/meta/*", eudore.HandlerMetadata) - app.AddMiddleware(func(ctx eudore.Context) { - ctx.WriteString("route: " + ctx.GetParam("route") + "\n") - }) app.GetFunc("/get/:name", func(ctx eudore.Context) { ctx.WriteString("get name: " + ctx.GetParam("name") + "\n") }) @@ -63,31 +59,34 @@ func main() { app.GetFunc("/get/:num|{min:100}", func(ctx eudore.Context) { ctx.WriteString("num great 100, num is: " + ctx.GetParam("num") + "\n") }) - // 校验函数,使用校验函数isnum。 - app.GetFunc("/get/:num|isnum", func(ctx eudore.Context) { - ctx.WriteString("isnum num is: " + ctx.GetParam("num") + "\n") + // 校验函数,使用校验函数num。 + app.GetFunc("/get/:num|num", func(ctx eudore.Context) { + ctx.WriteString("num num is: " + ctx.GetParam("num") + "\n") }) // 通配符研究不写了,和变量校验相同。 app.GetFunc("/*path|{^0.*$}", func(ctx eudore.Context) { ctx.WriteString("get path first char is '0', path is: " + ctx.GetParam("path") + "\n") }) + app.AddHandler("TEST", "/:path|{^0.*$}/*path|{^0.*$}", eudore.HandlerRouter404) - app.GetFunc("/:path|haha", eudore.HandlerRouter404) + app.GetFunc("/:path|enum=1,2,3", eudore.HandlerRouter404) // ---------- 分割线 运行测试请求 ---------- // 测试 - client := httptest.NewClient(app) - client.NewRequest("GET", "/get").Do().CheckStatus(200).CheckBodyContainString("get") - client.NewRequest("GET", "/get/ha").Do().CheckStatus(200).CheckBodyContainString("/get/:name") - client.NewRequest("GET", "/get/eudore").Do().CheckStatus(200).CheckBodyContainString("/get/eudore") - client.NewRequest("PUT", "/get/eudore").Do().CheckStatus(405) + status := eudore.NewClientCheckStatus + body := eudore.NewClientCheckBody + client := app.Client + client.NewRequest(nil, "GET", "/get", status(200), body("get")) + client.NewRequest(nil, "GET", "/get/ha", status(200), body("get name")) + client.NewRequest(nil, "GET", "/get/eudore", status(200), body("get eudore")) + client.NewRequest(nil, "PUT", "/get/eudore", status(405)) - client.NewRequest("GET", "/get/2").Do().CheckStatus(200).CheckBodyContainString("isnum") - client.NewRequest("GET", "/get/22").Do().CheckStatus(200).CheckBodyContainString("isnum") - client.NewRequest("GET", "/get/222").Do().CheckStatus(200).CheckBodyContainString("num great 100", "222") - client.NewRequest("GET", "/get/0xx").Do().CheckStatus(200).CheckBodyContainString("first char is '0'", "0xx") - client.NewRequest("XXX", "/get/0xx").Do().CheckStatus(405).Out() + client.NewRequest(nil, "GET", "/get/2", status(200), body("num")) + client.NewRequest(nil, "GET", "/get/22", status(200), body("num")) + client.NewRequest(nil, "GET", "/get/222", status(200), body("num great 100")) + client.NewRequest(nil, "GET", "/get/0xx", status(200), body("first char is '0'")) + client.NewRequest(nil, "XXX", "/get/0xx", status(405)) app.Listen(":8088") // app.CancelFunc() diff --git a/_example/router_test.go b/_example/router_test.go index 90412aa..79657cd 100644 --- a/_example/router_test.go +++ b/_example/router_test.go @@ -22,13 +22,15 @@ func TestRouterStdAdd(t *testing.T) { } app := eudore.NewApp() + app.SetValue(eudore.ContextKeyHandlerExtender, eudore.NewHandlerExtender()) + app.SetValue(eudore.ContextKeyRouter, eudore.NewRouter(nil)) app.AddHandlerExtend(func(str String) eudore.HandlerFunc { return func(ctx eudore.Context) { ctx.WriteString(str.Data) } }) app.AddMiddleware(middleware.NewRecoverFunc()) - app.AddController(&Test014Controller{}) + app.Group(" loggerkind=middleware").AddController(&Test014Controller{}) api := app.Group("/method") api.AddHandler("TEST", "/*", String{"test"}) @@ -107,7 +109,7 @@ func TestRouterMiddleware2(t *testing.T) { func TestRouterCoreLock(t *testing.T) { app := eudore.NewApp() - app.SetValue(eudore.ContextKeyRouter, eudore.NewRouterStd(eudore.NewRouterCoreLock(nil))) + app.SetValue(eudore.ContextKeyRouter, eudore.NewRouter(eudore.NewRouterCoreLock(nil))) app.Info(app.Router.(interface{ Metadata() interface{} }).Metadata()) app.AddMiddleware(middleware.NewLoggerFunc(app, "route")) app.GetFunc("/", eudore.HandlerEmpty) @@ -117,34 +119,13 @@ func TestRouterCoreLock(t *testing.T) { app.Run() } -func TestRouterCoreDebug(t *testing.T) { - app := eudore.NewApp() - app.SetValue(eudore.ContextKeyRouter, eudore.NewRouterStd(eudore.NewRouterCoreDebug(nil))) - app.AddMiddleware(middleware.NewLoggerFunc(app, "route")) - app.GetFunc("/", eudore.HandlerEmpty) - app.GetFunc("/index", eudore.HandlerEmpty) - app.GetFunc("/health", func(ctx eudore.Context) interface{} { - return app.Router.(interface{ Metadata() interface{} }).Metadata() - }) - app.GetFunc("/delete", eudore.HandlerEmpty) - app.GetFunc("/delete") - - app.NewRequest(nil, "GET", "/eudore/debug/router/data", - eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeApplicationJSON), - eudore.NewClientCheckStatus(200), - ) - app.NewRequest(nil, "GET", "/health") - app.CancelFunc() - app.Run() -} - func TestRouterCoreHost(t *testing.T) { echoHandleHost := func(ctx eudore.Context) { ctx.WriteString(ctx.GetParam("host")) } app := eudore.NewApp() - app.SetValue(eudore.ContextKeyRouter, eudore.NewRouterStd(eudore.NewRouterCoreHost(nil))) + app.SetValue(eudore.ContextKeyRouter, eudore.NewRouter(eudore.NewRouterCoreHost(nil))) app.AddMiddleware(middleware.NewLoggerFunc(app, "route")) app.AnyFunc("/* host=eudore.com", echoHandleHost) app.AnyFunc("/* host=eudore.com:8088", echoHandleHost) @@ -156,17 +137,20 @@ func TestRouterCoreHost(t *testing.T) { app.AnyFunc("/api/* host=eudore.com,eudore.cn", echoHandleHost) app.AnyFunc("/*", echoHandleHost) + host := func(h string) any { + return eudore.NewClientOptionHost(h) + } app.NewRequest(nil, "GET", "/", eudore.NewClientCheckStatus(200), eudore.NewClientCheckBody("")) - app.NewRequest(nil, "GET", "/", eudore.NewClientHost("eudore.cn"), eudore.NewClientCheckStatus(200), eudore.NewClientCheckBody("eudore.cn")) - app.NewRequest(nil, "GET", "/", eudore.NewClientHost("eudore.com"), eudore.NewClientCheckStatus(200), eudore.NewClientCheckBody("eudore.com")) - app.NewRequest(nil, "GET", "/", eudore.NewClientHost("eudore.com:8088"), eudore.NewClientCheckStatus(200), eudore.NewClientCheckBody("eudore.com")) - app.NewRequest(nil, "GET", "/", eudore.NewClientHost("eudore.com:8089"), eudore.NewClientCheckStatus(200), eudore.NewClientCheckBody("eudore.com")) - app.NewRequest(nil, "GET", "/", eudore.NewClientHost("eudore.net"), eudore.NewClientCheckStatus(200), eudore.NewClientCheckBody("eudore.*")) - app.NewRequest(nil, "GET", "/", eudore.NewClientHost("www.eudore.cn"), eudore.NewClientCheckStatus(200), eudore.NewClientCheckBody("www.*.cn")) - app.NewRequest(nil, "GET", "/", eudore.NewClientHost("example.com"), eudore.NewClientCheckStatus(200), eudore.NewClientCheckBody("example.com")) - app.NewRequest(nil, "GET", "/", eudore.NewClientHost("www.example"), eudore.NewClientCheckStatus(200), eudore.NewClientCheckBody("")) - app.NewRequest(nil, "GET", "/api/v1", eudore.NewClientHost("example.com"), eudore.NewClientCheckStatus(200), eudore.NewClientCheckBody("*")) - app.NewRequest(nil, "GET", "/api/v1", eudore.NewClientHost("eudore.com"), eudore.NewClientCheckStatus(200), eudore.NewClientCheckBody("eudore.com,eudore.cn")) + app.NewRequest(nil, "GET", "/", host("eudore.cn"), eudore.NewClientCheckStatus(200), eudore.NewClientCheckBody("eudore.cn")) + app.NewRequest(nil, "GET", "/", host("eudore.com"), eudore.NewClientCheckStatus(200), eudore.NewClientCheckBody("eudore.com")) + app.NewRequest(nil, "GET", "/", host("eudore.com:8088"), eudore.NewClientCheckStatus(200), eudore.NewClientCheckBody("eudore.com")) + app.NewRequest(nil, "GET", "/", host("eudore.com:8089"), eudore.NewClientCheckStatus(200), eudore.NewClientCheckBody("eudore.com")) + app.NewRequest(nil, "GET", "/", host("eudore.net"), eudore.NewClientCheckStatus(200), eudore.NewClientCheckBody("eudore.*")) + app.NewRequest(nil, "GET", "/", host("www.eudore.cn"), eudore.NewClientCheckStatus(200), eudore.NewClientCheckBody("www.*.cn")) + app.NewRequest(nil, "GET", "/", host("example.com"), eudore.NewClientCheckStatus(200), eudore.NewClientCheckBody("example.com")) + app.NewRequest(nil, "GET", "/", host("www.example"), eudore.NewClientCheckStatus(200), eudore.NewClientCheckBody("")) + app.NewRequest(nil, "GET", "/api/v1", host("example.com"), eudore.NewClientCheckStatus(200), eudore.NewClientCheckBody("*")) + app.NewRequest(nil, "GET", "/api/v1", host("eudore.com"), eudore.NewClientCheckStatus(200), eudore.NewClientCheckBody("eudore.com,eudore.cn")) app.CancelFunc() app.Run() diff --git a/_example/routerstd_test.go b/_example/routerstd_test.go index 41d1ffa..5a6d1ea 100644 --- a/_example/routerstd_test.go +++ b/_example/routerstd_test.go @@ -54,17 +54,17 @@ func TestRouterStdAny(t *testing.T) { func TestRouterStdCheck(t *testing.T) { app := eudore.NewApp() app.SetValue(eudore.ContextKeyFuncCreator, eudore.NewFuncCreator()) - app.SetValue(eudore.ContextKeyRouter, eudore.NewRouterStd(nil)) + app.SetValue(eudore.ContextKeyRouter, eudore.NewRouter(nil)) - app.AnyFunc("/1/:num|isnum version=1", eudore.HandlerEmpty) + app.AnyFunc("/1/:num|num version=1", eudore.HandlerEmpty) app.AnyFunc("/1/222", eudore.HandlerEmpty) app.AnyFunc("/2/:num|num", eudore.HandlerEmpty) app.AnyFunc("/2/:num|", eudore.HandlerEmpty) app.AnyFunc("/2/:", eudore.HandlerEmpty) - app.AnyFunc("/3/:num|isnum/22", eudore.HandlerEmpty) - app.AnyFunc("/3/:num|isnum/*", eudore.HandlerEmpty) - app.AnyFunc("/4/*num|isnum", eudore.HandlerEmpty) - app.AnyFunc("/4/*num|isnum", eudore.HandlerEmpty) + app.AnyFunc("/3/:num|num/22", eudore.HandlerEmpty) + app.AnyFunc("/3/:num|num/*", eudore.HandlerEmpty) + app.AnyFunc("/4/*num|num", eudore.HandlerEmpty) + app.AnyFunc("/4/*num|num", eudore.HandlerEmpty) app.AnyFunc("/4/*", eudore.HandlerEmpty) app.AnyFunc("/5/*num|num", eudore.HandlerEmpty) app.AnyFunc("/api/v1/2", eudore.HandlerEmpty) @@ -109,7 +109,7 @@ func TestRouterStdDelete(t *testing.T) { } app := eudore.NewApp() - app.SetValue(eudore.ContextKeyRouter, eudore.NewRouterStd(eudore.NewRouterCoreLock(nil))) + app.SetValue(eudore.ContextKeyRouter, eudore.NewRouter(eudore.NewRouterCoreLock(nil))) register := app.Group(" register=off") app.AnyFunc("/version", echoStringHandler("any version")) @@ -140,7 +140,7 @@ func TestRouterStdDelete(t *testing.T) { register.AddHandler("LOCK", "/api/v:v3/*", eudore.HandlerEmpty) // ---------------- 测试 ---------------- - app.SetValue(eudore.ContextKeyRouter, eudore.NewRouterStd(eudore.NewRouterCoreLock(nil))) + app.SetValue(eudore.ContextKeyRouter, eudore.NewRouter(eudore.NewRouterCoreLock(nil))) register = app.Group(" register=off") app.AnyFunc("/eudore/debug/look/*", middleware.NewLookFunc(app)) app.AnyFunc("/version", echoStringHandler("any version")) @@ -168,10 +168,10 @@ func TestRouterStdDelete(t *testing.T) { app.AnyFunc("/api/v1/user/id/:id", eudore.HandlerEmpty) app.AnyFunc("/api/v1/user/name/*name", eudore.HandlerEmpty) - app.AnyFunc("/api/v1/user/:id|isnum", eudore.HandlerEmpty) + app.AnyFunc("/api/v1/user/:id|num", eudore.HandlerEmpty) app.AnyFunc("/api/v1/user/*name|nozero", eudore.HandlerEmpty) - register.AnyFunc("/api/v1/user/:id|isnum/", eudore.HandlerEmpty) - register.AnyFunc("/api/v1/user/:id|isnum", eudore.HandlerEmpty) + register.AnyFunc("/api/v1/user/:id|num/", eudore.HandlerEmpty) + register.AnyFunc("/api/v1/user/:id|num", eudore.HandlerEmpty) register.AnyFunc("/api/v1/user/*name|nozero", eudore.HandlerEmpty) register.AnyFunc("/api/v1/user/id/:id", eudore.HandlerEmpty) register.AnyFunc("/api/v1/user/name/*name", eudore.HandlerEmpty) diff --git a/_example/serverConfig.go b/_example/serverConfig.go deleted file mode 100644 index e3ec6a4..0000000 --- a/_example/serverConfig.go +++ /dev/null @@ -1,16 +0,0 @@ -package main - -/* - */ - -import ( - "github.com/eudore/eudore" -) - -func main() { - app := eudore.NewApp() - - app.Listen(":8088") - // app.CancelFunc() - app.Run() -} diff --git a/_example/server_test.go b/_example/server_test.go index 89fb8ae..c960df2 100644 --- a/_example/server_test.go +++ b/_example/server_test.go @@ -3,7 +3,6 @@ package eudore_test import ( "crypto/tls" "crypto/x509" - "io/ioutil" "net" "net/http" "os" @@ -26,13 +25,7 @@ func TestServerStd(t *testing.T) { app.GetFunc("/panic", func(ctx eudore.Context) { panic(400) }) - app.GetFunc("/meta", func(ctx eudore.Context) interface{} { - meta, ok := app.Server.(interface{ Metadata() interface{} }) - if ok { - return meta.Metadata() - } - return nil - }) + app.GetFunc("/meta/*", eudore.HandlerMetadata) app.GetFunc("/err", func(ctx eudore.Context) { var app eudore.App app.CancelFunc() @@ -43,7 +36,7 @@ func TestServerStd(t *testing.T) { app.NewRequest(nil, "GET", "/wrote") app.NewRequest(nil, "GET", "/panic") app.NewRequest(nil, "GET", "/err") - app.NewRequest(nil, "GET", "/meta", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeApplicationJSON)) + app.NewRequest(nil, "GET", "/meta/server") app.CancelFunc() app.Run() @@ -164,7 +157,7 @@ func TestServerMutualTLS(t *testing.T) { func createtp() (*http.Transport, error) { pool := x509.NewCertPool() - data, err := ioutil.ReadFile("/tmp/mca/ca.cer") + data, err := os.ReadFile("/tmp/mca/ca.cer") pool.AppendCertsFromPEM(data) if err != nil { return nil, err @@ -210,7 +203,7 @@ openssl x509 -in server.cer -text -noout 2>&1| head -n 15 */ func createssl() { - ioutil.WriteFile("ca.cer", []byte(`-----BEGIN CERTIFICATE----- + os.WriteFile("ca.cer", []byte(`-----BEGIN CERTIFICATE----- MIIDYjCCAkoCCQDkcCR+EmTT1jANBgkqhkiG9w0BAQUFADBzMQswCQYDVQQGEwJD TjELMAkGA1UECAwCQkoxEDAOBgNVBAcMB2JlaWppbmcxDzANBgNVBAoMBmV1ZG9y ZTEPMA0GA1UECwwGZXVkb3JlMQ8wDQYDVQQLDAZldWRvcmUxEjAQBgNVBAMMCWxv @@ -232,7 +225,7 @@ mjXWmb2mT7j+oZ4P84UFWZsTLgqZaCnY3x+9f7yh00cPGcLwdaC+UjmEsluGwtMG +8lmw2Fo -----END CERTIFICATE----- `), 0644) - ioutil.WriteFile("server.key", []byte(`-----BEGIN RSA PRIVATE KEY----- + os.WriteFile("server.key", []byte(`-----BEGIN RSA PRIVATE KEY----- MIIEowIBAAKCAQEA0zQ4iUpcInR2++dwEdTH/VGtFDXZEMKqZjdFuMQmqkB8+l7f 2zaTDAweVkaGW7dAWyGQFl/p5GYEFyzdjj5lKhP9bSClxGAFnDjQbP9MI4zKKhtW qWXs1QASUhcM3NpaswUQlu3GtHuxkzZsJv+Y4wvx+6oEwrvqIfh6y+8IsAFHSDHw @@ -260,7 +253,7 @@ d7I+u8v/415SnnwAp4HxpMCe1WgmoMotFPHrA3j9FqmQ74A3TzXV0BvasZK9Pd36 LBPlsSEa0bH2ZjPim2xir8m92MKT3npoe9Qu9U49ozFSNFbx9DeK -----END RSA PRIVATE KEY----- `), 0644) - ioutil.WriteFile("server.cer", []byte(`-----BEGIN CERTIFICATE----- + os.WriteFile("server.cer", []byte(`-----BEGIN CERTIFICATE----- MIIDYjCCAkoCCQDKHAIPMuQDNDANBgkqhkiG9w0BAQUFADBzMQswCQYDVQQGEwJD TjELMAkGA1UECAwCQkoxEDAOBgNVBAcMB2JlaWppbmcxDzANBgNVBAoMBmV1ZG9y ZTEPMA0GA1UECwwGZXVkb3JlMQ8wDQYDVQQLDAZldWRvcmUxEjAQBgNVBAMMCWxv @@ -282,7 +275,7 @@ wL5oxRH7f4om4vsY1Uhe6VUuXh1R05rlCnX9l/HPazaUj2zEGHh/Drxj6BcwSlOr uZwwKuNX -----END CERTIFICATE----- `), 0644) - ioutil.WriteFile("client.key", []byte(`-----BEGIN RSA PRIVATE KEY----- + os.WriteFile("client.key", []byte(`-----BEGIN RSA PRIVATE KEY----- MIIEowIBAAKCAQEApd+aNQYV4DuWG9Sb9iA6wdEMvRG7tIjBv78Nm9BxZwHaVgwd BEucdX1ieiyDSVJOVxBKjyGmehtE5VH1kAUHzln2TMKYj8cKb78BNmuU+VqV/C2G VbuCtB3wT3IXcI2L4ZvM/9w3INZbxo1oTiTfj2wiyo1ZSXygPuWIk+ardb2/u/ZU @@ -310,7 +303,7 @@ IMoUnpMmCrDHeGZcfAkmZ4XDEXWn4gGIybQuqsJ1znfL6i0S68X/IlLyAuuz8D5Z q9GrgvzRgNgBZxSZagF0Xjk+wjvDiqe4N3kPYu2qxcpg555kaJgX -----END RSA PRIVATE KEY----- `), 0644) - ioutil.WriteFile("client.cer", []byte(`-----BEGIN CERTIFICATE----- + os.WriteFile("client.cer", []byte(`-----BEGIN CERTIFICATE----- MIIDYjCCAkoCCQDKHAIPMuQDNTANBgkqhkiG9w0BAQUFADBzMQswCQYDVQQGEwJD TjELMAkGA1UECAwCQkoxEDAOBgNVBAcMB2JlaWppbmcxDzANBgNVBAoMBmV1ZG9y ZTEPMA0GA1UECwwGZXVkb3JlMQ8wDQYDVQQLDAZldWRvcmUxEjAQBgNVBAMMCWxv diff --git a/_example/util16_test.go b/_example/util16_test.go deleted file mode 100644 index ec046ca..0000000 --- a/_example/util16_test.go +++ /dev/null @@ -1,29 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package eudore_test - -import ( - "embed" - "testing" - "time" - - "github.com/eudore/eudore" -) - -//go:embed *.go -var root embed.FS - -func TestUtilPatch16(t *testing.T) { - eudore.NewHandlerEmbedFunc(root, ".") - eudore.DefaultEmbedTime = time.Now() - - app := eudore.NewApp() - app.GetFunc("/static/*", root) - - app.NewRequest(nil, "GET", "/static/app_test.go") - app.NewRequest(nil, "GET", "/static/none_test.go") - - app.CancelFunc() - app.Run() -} diff --git a/_example/util_test.go b/_example/util_test.go index 5d00d06..be58fba 100644 --- a/_example/util_test.go +++ b/_example/util_test.go @@ -1,120 +1,26 @@ package eudore_test import ( + "context" + "encoding/json" "testing" + "time" "github.com/eudore/eudore" ) -func TestUtilContextKey(*testing.T) { - app := eudore.NewApp() - app.Info(eudore.NewContextKey("debug-key")) - - app.CancelFunc() - app.Run() +func TestUtilContextKey(t *testing.T) { + t.Log(eudore.NewContextKey("debug-key")) } -func TestUtilTimeDuration(*testing.T) { - type Data struct { - Time eudore.TimeDuration `json:"time"` - } - - app := eudore.NewApp() - app.AnyFunc("/time/*", func(ctx eudore.Context) interface{} { - return eudore.TimeDuration(12000000000) - }) - app.AnyFunc("/time/bind", func(ctx eudore.Context) error { - var data Data - ctx.Debug(string(ctx.Body())) - err := ctx.Bind(&data) - ctx.Info(err, data) - return err - }) - - app.NewRequest(nil, "GET", "/time/text", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeText)) - app.NewRequest(nil, "GET", "/time/json", eudore.NewClientHeader(eudore.HeaderAccept, eudore.MimeApplicationJSON)) - app.NewRequest(nil, "PUT", "/time/bind", eudore.NewClientHeader(eudore.HeaderContentType, eudore.MimeApplicationJSON), eudore.NewClientBodyString(`{"time":"12s"}`)) - app.NewRequest(nil, "PUT", "/time/bind", eudore.NewClientHeader(eudore.HeaderContentType, eudore.MimeApplicationJSON), eudore.NewClientBodyString(`{"time":12000000000}`)) - app.NewRequest(nil, "PUT", "/time/bind", eudore.NewClientHeader(eudore.HeaderContentType, eudore.MimeApplicationJSON), eudore.NewClientBodyString(`{"time":"x"}`)) - - app.CancelFunc() - app.Run() -} - -func TestUtilGetCast(t *testing.T) { - app := eudore.NewApp() - - app.Debug(eudore.GetBool(int(1))) - app.Debug(eudore.GetBool(uint(1))) - app.Debug(eudore.GetBool(float32(1.0))) - app.Debug(eudore.GetBool("true")) - app.Debug(eudore.GetInt(int(123))) - app.Debug(eudore.GetInt(uint(234))) - app.Debug(eudore.GetInt(float64(345))) - app.Debug(eudore.GetInt("456")) - app.Debug(eudore.GetInt64(int(123))) - app.Debug(eudore.GetInt64(uint(234))) - app.Debug(eudore.GetInt64(float64(345))) - app.Debug(eudore.GetInt64("456")) - app.Debug(eudore.GetUint(int(123))) - app.Debug(eudore.GetUint(uint(234))) - app.Debug(eudore.GetUint(float64(345))) - app.Debug(eudore.GetUint("456")) - app.Debug(eudore.GetUint64(int(123))) - app.Debug(eudore.GetUint64(uint(234))) - app.Debug(eudore.GetUint64(float64(345))) - app.Debug(eudore.GetUint64("456")) - app.Debug(eudore.GetFloat32(int(123))) - app.Debug(eudore.GetFloat32(uint(234))) - app.Debug(eudore.GetFloat32(float64(345))) - app.Debug(eudore.GetFloat32("456")) - app.Debug(eudore.GetFloat64(int(123))) - app.Debug(eudore.GetFloat64(uint(234))) - app.Debug(eudore.GetFloat64(float64(345))) - app.Debug(eudore.GetFloat64("456")) - app.Debug(eudore.GetString(int(123))) - app.Debug(eudore.GetString(uint(234))) - app.Debug(eudore.GetString(float64(345))) - app.Debug(eudore.GetString("456")) - app.Debug(eudore.GetString([]byte("456"))) - app.Debug(eudore.GetString(true)) - app.Debug(eudore.GetString(eudore.NewContextKey("string"))) - app.Debug(eudore.GetBytes("strings")) - app.Debug(eudore.GetStrings("strings")) - app.Debug(eudore.GetStrings([]interface{}{"1", "2", "3"})) - - app.CancelFunc() - app.Run() - -} - -func TestUtilGetCastString(t *testing.T) { - app := eudore.NewApp() - - app.Debug(eudore.GetStringBool("true")) - app.Debug(eudore.GetStringBool("1")) - app.Debug(eudore.GetStringBool("bool")) - app.Debug(eudore.GetStringInt("1")) - app.Debug(eudore.GetStringInt("0", 1)) - app.Debug(eudore.GetStringInt("0", 0)) - app.Debug(eudore.GetStringInt64("1")) - app.Debug(eudore.GetStringInt64("0", 1)) - app.Debug(eudore.GetStringInt64("0", 0)) - app.Debug(eudore.GetStringUint("1")) - app.Debug(eudore.GetStringUint("0", 1)) - app.Debug(eudore.GetStringUint("0", 0)) - app.Debug(eudore.GetStringUint64("1")) - app.Debug(eudore.GetStringUint64("0", 1)) - app.Debug(eudore.GetStringUint64("0", 0)) - app.Debug(eudore.GetStringFloat32("1")) - app.Debug(eudore.GetStringFloat32("0", 1)) - app.Debug(eudore.GetStringFloat32("0", 0)) - app.Debug(eudore.GetStringFloat64("1")) - app.Debug(eudore.GetStringFloat64("0", 1)) - app.Debug(eudore.GetStringFloat64("0", 0)) - - app.CancelFunc() - app.Run() +func TestUtilTimeDuration(t *testing.T) { + var data eudore.TimeDuration + t.Log(json.Unmarshal([]byte(`"12s"`), &data), data) + t.Log(json.Unmarshal([]byte(`12000000000`), &data), data) + t.Log(json.Unmarshal([]byte(`"x"`), &data), data) + b, _ := json.Marshal(data) + t.Log(string(b)) + t.Log(data) } func TestUtilGetWarp(t *testing.T) { @@ -141,50 +47,220 @@ func TestUtilGetWarp(t *testing.T) { } warp := eudore.NewGetWarpWithObject(data) - app.Info("%#v", warp.GetInterface("")) - - app.Info(warp.GetInt("int")) - app.Info(warp.GetInt64("int")) - app.Info(warp.GetUint("int")) - app.Info(warp.GetUint64("int")) - app.Info(warp.GetFloat32("int")) - app.Info(warp.GetFloat64("int")) - app.Info(warp.GetInt("int8")) - app.Info(warp.GetInt64("int8")) - app.Info(warp.GetUint("int8")) - app.Info(warp.GetUint64("int8")) - app.Info(warp.GetFloat32("int8")) - app.Info(warp.GetFloat64("int8")) - - app.Info(warp.GetInt("int1")) - app.Info(warp.GetInt64("int1")) - app.Info(warp.GetUint("int1")) - app.Info(warp.GetUint64("int1")) - app.Info(warp.GetFloat32("int1")) - app.Info(warp.GetFloat64("int1")) - app.Info(warp.GetInt("int1", 3)) - app.Info(warp.GetInt64("int1", 3)) - app.Info(warp.GetUint("int1", 3)) - app.Info(warp.GetUint64("int1", 3)) - app.Info(warp.GetFloat32("int1", 3)) - app.Info(warp.GetFloat64("int1", 3)) - - app.Info(warp.GetString("int")) - app.Info(warp.GetString("string")) - app.Info(warp.GetString("nil")) - app.Info(warp.GetString("bytes")) - app.Info(warp.GetString("int", "default")) - app.Info(warp.GetString("string", "default")) - app.Info(warp.GetString("nil", "default")) - - app.Info(warp.GetBytes("int")) - app.Info(warp.GetBytes("string")) - app.Info(warp.GetBytes("nil")) - app.Info(warp.GetBytes("bytes")) - - app.Info(warp.GetStrings("nil")) - app.Info(warp.GetStrings("string")) - app.Info(warp.GetStrings("arrayint")) - app.Info(warp.GetStrings("arraystr")) - app.Info(warp.GetStrings("arraybyte")) + t.Logf("%#v", warp.GetAny("")) + t.Log(warp.GetInt("int")) + t.Log(warp.GetInt64("int")) + t.Log(warp.GetUint("int")) + t.Log(warp.GetUint64("int")) + t.Log(warp.GetFloat32("int")) + t.Log(warp.GetFloat64("int")) + t.Log(warp.GetString("int")) +} + +func TestUtilGetAnyValue(t *testing.T) { + t.Log(eudore.GetAnyDefault("default", "")) + t.Log(eudore.GetAnyDefault("", "")) + t.Log(eudore.GetAnyDefault("", "default string")) + t.Log(eudore.GetAnyDefaults("default", "")) + t.Log(eudore.GetAnyDefaults("", "")) + t.Log(eudore.GetAnyDefaults("", "default string")) + + t.Log(eudore.GetAny("", "default string")) + t.Log(eudore.GetAny[int](nil)) + t.Log(eudore.GetAny[int](12)) + t.Log(eudore.GetAny[int](uint(12))) + t.Log(eudore.GetAny[int]("12")) + t.Log(eudore.GetAny[string](12)) + t.Log(eudore.GetAny[int64](time.Second)) + + t.Log(eudore.GetStringByAny(eudore.GetAnyByString[string]("string"))) + t.Log(eudore.GetStringByAny(eudore.GetAnyByString[bool]("true"))) + t.Log(eudore.GetStringByAny(eudore.GetAnyByString[bool]("false"))) + t.Log(eudore.GetStringByAny(eudore.GetAnyByString[time.Time]("20180801"))) + t.Log(eudore.GetStringByAny(eudore.GetAnyByString[time.Duration]("200h"))) + t.Log(eudore.GetStringByAny(eudore.GetAnyByString[int]("12"))) + t.Log(eudore.GetStringByAny(eudore.GetAnyByString[int8]("12"))) + t.Log(eudore.GetStringByAny(eudore.GetAnyByString[int16]("12"))) + t.Log(eudore.GetStringByAny(eudore.GetAnyByString[int32]("12"))) + t.Log(eudore.GetStringByAny(eudore.GetAnyByString[int64]("12"))) + t.Log(eudore.GetStringByAny(eudore.GetAnyByString[uint]("12"))) + t.Log(eudore.GetStringByAny(eudore.GetAnyByString[uint8]("12"))) + t.Log(eudore.GetStringByAny(eudore.GetAnyByString[uint16]("12"))) + t.Log(eudore.GetStringByAny(eudore.GetAnyByString[uint32]("12"))) + t.Log(eudore.GetStringByAny(eudore.GetAnyByString[uint64]("12"))) + t.Log(eudore.GetStringByAny(eudore.GetAnyByString[float32]("12"))) + t.Log(eudore.GetStringByAny(eudore.GetAnyByString[float64]("12"))) + t.Log(eudore.GetStringByAny(eudore.GetAnyByString[complex64]("1+2i"))) + t.Log(eudore.GetStringByAny(eudore.GetAnyByString[complex128]("1+2i"))) + t.Log(eudore.GetStringByAny([]byte("bytes"))) + t.Log(eudore.GetStringByAny(eudore.GetStringByAny)) + t.Log(eudore.GetStringByAny("")) + t.Log(eudore.GetStringByAny("", "0")) +} + +func TestUtilGetSetValue(t *testing.T) { + type Field struct { + Index int `alias:"index"` + Name string `alias:"name"` + } + type config struct { + Name string `alias:"name"` + int int `alias:"int` + ano string + Ptr *Field `alias:"ptr"` + Array [4]Field `alias:"array"` + Slice []Field `alias:"slice"` + Map map[int]string `alias:"map"` + Any any `alias:"any"` + Context context.Context `alias:"context"` + *Field + } + + get := func(i any, key string) any { + val, err := eudore.GetAnyByPathWithTag(i, key, nil, false) + if err != nil { + return err + } + return val + } + set := func(i any, key string, val any) error { + return eudore.SetAnyByPath(i, key, val) + } + data := new(config) + + t.Log(eudore.SetAnyByPathWithTag(data, "ano", "ano field", nil, true)) + t.Log(eudore.GetAnyByPathWithTag(data, "ano", nil, true)) + t.Log(eudore.GetAnyByPathWithTag(data, "int", nil, true)) + t.Log(eudore.GetAnyByPathWithValue(data, "int", nil, true)) + t.Log(eudore.GetAnyByPath(data, "int")) + // t.Log(eudore.GetAnyByPath(data, "ptr")) + + get(nil, "") + t.Logf("%#v", get(data, "")) + t.Log(get(data, "ptr.key")) + t.Log(get(data, "name.num")) + t.Log(get(data, "null")) + t.Log(get(data, "int")) + t.Log(get(data, "map.0")) + t.Log(get(data, "slice.0")) + t.Log(get(data, "index")) + + t.Log(set(data, "", 0)) + t.Log(set(*data, "name", 0)) + t.Log(set(data, "name.null", 0)) + t.Log(set(data, "ptr.null", 0)) + t.Log(set(data, "int", 0)) + t.Log(set(data, "context.4", 0)) + t.Log(set(data, "array.x", 0)) + t.Log(set(data, "slice.x", 0)) + t.Log(set(data, "map.xs", 0)) + t.Log(set(data, "index", "x")) + t.Log(set(data, "index", 11)) + + t.Log(set(data, "ptr.index", 12)) + t.Log(set(data, "array.0.index", 13)) + t.Log(set(data, "array.-1.index", 14)) + t.Log(set(data, "slice.5.index", 15)) + t.Log(set(data, "slice.[].index", 16)) + t.Log(set(data, "slice.-1.index", 17)) + t.Log(set(data, "any.8", 18)) + t.Log(set(data, "any.9", 19)) + t.Log(set(data, "map.9", "map9 hello")) + t.Log(set(data, "map.9", "map9 hello")) + + t.Log(get(data, "map.xs")) + t.Log(get(data, "map.0")) + t.Log(get(data, "map.9")) + t.Log(get(data, "array.x.index")) + t.Log(get(data, "array.-1.index")) + t.Log(get(data, "array.0.index")) + t.Log(get(data, "index")) + t.Logf("%#v", get(data, "")) +} + +func TestUtilSetWithValue(t *testing.T) { + type time2 time.Time + type config struct { + Ptr *time.Duration `alias:"ptr"` + Slice []int `alias:"slice"` + Int int `alias:"int"` + Uint uint `alias:"uint"` + Bool bool `alias:"bool"` + Float float64 `alias:"float"` + Complex complex64 `alias:"complex"` + Time time.Time `alias:"time"` + Time2 time2 `alias:"time2"` + Struct struct{} `alias:"struct"` + Bytes []byte `alias:"bytes"` + Runes []rune `alias:"runes"` + Any any `alias:"any"` + Face json.Marshaler `alias:"face"` + Chan chan int `alias:"chan"` + ano string + } + + data := new(config) + eudore.SetAnyByPath(data, "ptr", t) + eudore.SetAnyByPath(data, "ptr", time.Second) + t.Logf("%p", data.Ptr) + d := eudore.TimeDuration(time.Second) + eudore.SetAnyByPath(data, "ptr", &d) + t.Logf("%p %s", data.Ptr, d) + eudore.SetAnyByPath(data, "ptr", "12x") + eudore.SetAnyByPath(data, "ptr", "12s") + t.Logf("%p %s", data.Ptr, d) + + eudore.SetAnyByPath(data, "slice", "12s") + eudore.SetAnyByPath(data, "slice", "12") + eudore.SetAnyByPath(data, "slice", []string{"1", "2", "3"}) + eudore.SetAnyByPath(data, "slice", []string{"a", "x", "c"}) + + eudore.SetAnyByPath(data, "int", "") + eudore.SetAnyByPath(data, "uint", "") + eudore.SetAnyByPath(data, "bool", "") + eudore.SetAnyByPath(data, "float", "") + eudore.SetAnyByPath(data, "complex", "") + eudore.SetAnyByPath(data, "complex", "0+x") + eudore.SetAnyByPath(data, "complex", "0i+x") + t.Log(eudore.SetAnyByPath(data, "time", "2018")) + t.Log(eudore.SetAnyByPath(data, "chan", "2018")) + + t.Log(eudore.SetAnyByPath(data, "int", "1")) + t.Log(eudore.SetAnyByPath(data, "uint", "1")) + t.Log(eudore.SetAnyByPath(data, "bool", "1")) + t.Log(eudore.SetAnyByPath(data, "float", "1")) + t.Log(eudore.SetAnyByPath(data, "complex", "1i")) + t.Log(eudore.SetAnyByPath(data, "time", "20180801")) + t.Log(eudore.SetAnyByPath(data, "time2", "20180801")) + t.Log(eudore.SetAnyByPath(data, "bytes", "bytes")) + t.Log(eudore.SetAnyByPath(data, "runes", "runes")) + t.Log(eudore.SetAnyByPath(data, "any", "any")) + t.Log(eudore.SetAnyByPath(data, "face", "any")) + t.Log(eudore.SetAnyByPath(data, "struct", "struct")) + t.Log(eudore.SetAnyByPathWithTag(data, "ano", time.Now(), nil, true)) + t.Logf("%#v", eudore.GetAnyByPath(data, "")) + + type M struct { + M1 map[string]any `alias:"m1"` + M2 map[*string]any `alias:"m2"` + M3 map[eudore.LoggerLevel]any `alias:"m3"` + M4 map[eudore.TimeDuration]any `alias:"m4"` + M5 map[any]any `alias:"m5"` + } + + m := &M{} + t.Log(eudore.SetAnyByPath(m, "m1.1", "1")) + t.Log(eudore.SetAnyByPath(m, "m2.3", "1")) + t.Log(eudore.SetAnyByPath(m, "m3.ERROR", "1")) + t.Log(eudore.SetAnyByPath(m, "m4.4s", "1")) + t.Log(eudore.SetAnyByPath(m, "m5.5", "1")) + t.Logf("%#v", m) + + type Cycle struct { + *Cycle + } + c := &Cycle{} + c.Cycle = c + t.Log(eudore.SetAnyByPathWithTag(c, "name", "eudore", nil, false)) + t.Log(eudore.GetAnyByPathWithTag(c, "name", nil, false)) } diff --git a/app.go b/app.go index 992c57b..8ffd74e 100644 --- a/app.go +++ b/app.go @@ -3,13 +3,9 @@ Package eudore golang http framework, less is more. source: https://github.com/eudore/eudore -document: https://www.eudore.cn - -exapmle: https://github.com/eudore/eudore/tree/master/_example - wiki: https://github.com/eudore/eudore/wiki -godoc: https://godoc.org/github.com/eudore/eudore +exapmle: https://github.com/eudore/eudore/tree/master/_example godev: https://pkg.go.dev/github.com/eudore/eudore */ @@ -19,32 +15,25 @@ package eudore // import "github.com/eudore/eudore" import ( "context" + "errors" "fmt" + "html/template" "net" "net/http" "sync" ) /* -App combines the main functional interfaces and only implements simple basic methods. +The App struct is defined as the main object for the application, +which combines various functional interfaces and implements basic methods. +It provides additional features such as -The following functions are realized in addition to the functions of the combined components: Manage Object Lifecycle Store global data Register global middleware Start port monitoring Block running service Get configuration value and convert type - -App 组合主要功能接口,本身仅实现简单的基本方法。 - -组合各组件功能外实现下列功能: - 管理对象生命周期 - 存储全局数据 - 注册全局中间件 - 启动端口监听 - 阻塞运行服务 - 获取配置值并转换类型 */ type App struct { context.Context `alias:"context"` @@ -59,29 +48,30 @@ type App struct { HandlerFuncs HandlerFuncs `alias:"handlerfuncs"` ContextPool *sync.Pool `alias:"contextpool"` CancelError error `alias:"cancelerror"` - cancelMutex sync.Mutex - Values []interface{} + cancelMutex sync.Mutex `alias:"cancelmutex"` + Values []any `alias:"values"` } -// NewApp function creates an App object. -// -// NewApp 函数创建一个App对象。 +// The NewApp() function creates an App object, initializes various components of the application, and returns the App object. func NewApp() *App { app := &App{} app.GetWarp = NewGetWarpWithApp(app) app.HandlerFuncs = HandlerFuncs{app.serveContext} app.Context, app.CancelFunc = context.WithCancel(context.Background()) - app.SetValue(ContextKeyLogger, NewLoggerStd(nil)) - app.SetValue(ContextKeyConfig, NewConfigStd(nil)) - app.SetValue(ContextKeyDatabase, NewDatabaseStd(nil)) - app.SetValue(ContextKeyClient, NewClientStd()) - app.SetValue(ContextKeyServer, NewServerStd(nil)) - app.SetValue(ContextKeyRouter, NewRouterStd(nil)) + app.SetValue(ContextKeyLogger, NewLogger(nil)) + app.SetValue(ContextKeyConfig, NewConfig(nil)) + app.SetValue(ContextKeyDatabase, NewDatabase(nil)) + app.SetValue(ContextKeyClient, NewClient()) + app.SetValue(ContextKeyServer, NewServer(nil)) + app.SetValue(ContextKeyRouter, NewRouter(nil)) + app.SetValue(ContextKeyBind, NewBinds(nil)) + app.SetValue(ContextKeyRender, NewRenders(nil)) + app.SetValue(ContextKeyTemplate, template.Must(template.New("").Parse(DefaultTemplateInit))) app.ContextPool = NewContextBasePool(app) return app } -// Run method starts the App to block and wait for the App to end. +// The Run() method starts the application and blocks it until it is finished. // // Run 方法启动App阻塞等待App结束。 func (app *App) Run() error { @@ -95,10 +85,12 @@ func (app *App) Run() error { for i := len(app.Values) - 2; i > -1; i -= 2 { app.SetValue(app.Values[i], nil) } - if app.Err() == context.Canceled { - app.Info("eudore app cannel context") + + log := app.WithField(ParamDepth, 2) + if errors.Is(app.Err(), context.Canceled) { + log.Info("eudore app", app.Err()) } else { - app.Fatal("eudore app cannel context error:", app.Err()) + log.Fatal("eudore app error:", app.Err()) } }() <-app.Done() @@ -110,26 +102,26 @@ func (app *App) Run() error { // this method is automatically called when setting and unsetting. // // SetValue 方法从App设置指定键值,如果值实现Mount/Unmount方法在设置和取消设置时自动调用该方法。 -func (app *App) SetValue(key, val interface{}) { - withMount(app, val) +func (app *App) SetValue(key, val any) { + anyMount(app, val) switch key { case ContextKeyLogger: - defer withUnmount(app, app.Logger) + defer anyUnmount(app, app.Logger) app.Logger, _ = val.(Logger) case ContextKeyConfig: - defer withUnmount(app, app.Config) + defer anyUnmount(app, app.Config) app.Config, _ = val.(Config) case ContextKeyDatabase: - defer withUnmount(app, app.Database) + defer anyUnmount(app, app.Database) app.Database, _ = val.(Database) case ContextKeyClient: - defer withUnmount(app, app.Client) + defer anyUnmount(app, app.Client) app.Client, _ = val.(Client) case ContextKeyServer: - defer withUnmount(app, app.Server) + defer anyUnmount(app, app.Server) app.Server, _ = val.(Server) case ContextKeyRouter: - defer withUnmount(app, app.Router) + defer anyUnmount(app, app.Router) app.Router, _ = val.(Router) case ContextKeyContextPool: app.ContextPool, _ = val.(*sync.Pool) @@ -147,7 +139,7 @@ func (app *App) SetValue(key, val interface{}) { default: for i := 0; i < len(app.Values); i += 2 { if app.Values[i] == key { - defer withUnmount(app, app.Values[i+1]) + defer anyUnmount(app, app.Values[i+1]) app.Values[i+1] = val return } @@ -159,7 +151,7 @@ func (app *App) SetValue(key, val interface{}) { // Value method gets the specified key value from the App. // // Value 方法从App获取指定键值。 -func (app *App) Value(key interface{}) interface{} { +func (app *App) Value(key any) any { switch key { case ContextKeyApp: return app @@ -175,6 +167,16 @@ func (app *App) Value(key interface{}) interface{} { return app.Server case ContextKeyRouter: return app.Router + case ContextKeyAppKeys: + keys := make([]any, 0, 6+len(app.Values)/2) + keys = append(keys, + ContextKeyLogger, ContextKeyConfig, ContextKeyDatabase, + ContextKeyClient, ContextKeyServer, ContextKeyRouter, + ) + for i := 0; i < len(app.Values); i += 2 { + keys = append(keys, app.Values[i]) + } + return keys } for i := 0; i < len(app.Values); i += 2 { if app.Values[i] == key { @@ -196,22 +198,22 @@ func (app *App) Err() error { return app.Context.Err() } -func withMount(ctx context.Context, i interface{}) { +func anyMount(ctx context.Context, i any) { loader, ok := i.(interface{ Mount(context.Context) }) if ok { loader.Mount(ctx) } } -func withUnmount(ctx context.Context, i interface{}) { +func anyUnmount(ctx context.Context, i any) { closer, ok := i.(interface{ Unmount(context.Context) }) if ok { closer.Unmount(ctx) } } -func withMetadata(i interface{}) interface{} { - metaer, ok := i.(interface{ Metadata() interface{} }) +func anyMetadata(i any) any { + metaer, ok := i.(interface{ Metadata() any }) if ok { return metaer.Metadata() } @@ -248,18 +250,19 @@ func (app *App) ServeHTTP(w http.ResponseWriter, r *http.Request) { // AddMiddleware If the first parameter of the AddMiddleware method is the string "global", // it will be added to the App as a global request middleware, -// using DefaultHandlerExtend to create a request processing function, +// using NewHandlerExtenderWithContext to create a request processing function, // otherwise it is equivalent to calling the app.Rputer.AddMiddleware method. // // AddMiddleware 方法如果第一个参数为字符串"global", -// 为全局请求中间件添加给App(使用DefaultHandlerExtend创建请求处理函数), +// 为全局请求中间件添加给App(使用NewHandlerExtenderWithContext创建请求处理函数), // 否则等同于调用app.Rputer.AddMiddleware方法。 -func (app *App) AddMiddleware(hs ...interface{}) error { +func (app *App) AddMiddleware(hs ...any) error { if len(hs) > 1 { name, ok := hs[0].(string) if ok && name == "global" { - handler := DefaultHandlerExtend.NewHandlerFuncs("", hs[1:]) - app.WithField("depth", 1).Info("Register app global middleware:", handler) + handler := NewHandlerExtenderWithContext(app).CreateHandler("", hs[1:]) + app.WithField(ParamDepth, 1). + Info("Register app global middleware:", handler) last := app.HandlerFuncs[len(app.HandlerFuncs)-1] app.HandlerFuncs = NewHandlerFuncsCombine(app.HandlerFuncs[0:len(app.HandlerFuncs)-1], handler) app.HandlerFuncs = NewHandlerFuncsCombine(app.HandlerFuncs, HandlerFuncs{last}) @@ -281,7 +284,8 @@ func (app *App) Listen(addr string) error { app.Error(err) return err } - app.Logger.WithField("depth", 1).Infof("listen http in %s %s", ln.Addr().Network(), ln.Addr().String()) + app.WithField(ParamDepth, 1).Infof("listen http in %s %s", + ln.Addr().Network(), ln.Addr().String()) app.Serve(ln) return nil } @@ -302,7 +306,7 @@ func (app *App) ListenTLS(addr, key, cert string) error { app.Error(err) return err } - app.Logger.WithField("depth", 1).Infof("listen https in %s %s,host name: %v", + app.WithField(ParamDepth, 1).Infof("listen https in %s %s, host name: %v", ln.Addr().Network(), ln.Addr().String(), conf.Certificate.DNSNames) app.Serve(ln) return nil @@ -317,3 +321,11 @@ func (app *App) Serve(ln net.Listener) { app.SetValue(ContextKeyError, srv.Serve(ln)) }() } + +func (app *App) Parse() error { + err := app.Config.Parse(app) + if err != nil { + app.SetValue(ContextKeyError, err) + } + return err +} diff --git a/client.go b/client.go index 1b20010..46b2e09 100644 --- a/client.go +++ b/client.go @@ -3,51 +3,62 @@ package eudore import ( "bytes" "context" - "crypto/rand" - "crypto/sha256" - "crypto/tls" - "encoding/hex" "encoding/json" "encoding/xml" + "errors" "fmt" "io" - "io/ioutil" "mime/multipart" "net" "net/http" - "net/http/httptrace" "net/url" "os" "path/filepath" - "reflect" "strings" "time" ) // Client 定义http客户端接口,构建并发送http请求。 type Client interface { - NewRequest(context.Context, string, string, ...interface{}) error - WithClient(...interface{}) Client + NewRequest(context.Context, string, string, ...any) error + WithClient(...any) Client GetClient() *http.Client } -// ClientRequestOption 定义http请求选项。 -type ClientRequestOption func(*http.Request) - -// ClientResponseOption 定义http响应选项。 -type ClientResponseOption func(*http.Response) error - // clientStd 定义http客户端默认实现。 type clientStd struct { - Context context.Context - Client *http.Client - Options []interface{} + Client *http.Client `alias:"client"` + Option *ClientOption `alias:"option"` +} + +/* +ClientBody defines the client Body. + +The GetBody method returns a shallow copy of the data for request redirection and retry. + +The AddValue method sets the data saved by the body. + +The AddFile method can add file upload when using MultipartForm. + +ClientBody 定义客户端Body。 + +GetBody方法返回数据浅复制用于请求重定向和重试。 + +AddValue方法设置body保存的数据。 + +AddFile方法在MultipartForm时可以添加文件上传。 +*/ +type ClientBody interface { + io.ReadCloser + GetContentType() string + GetBody() (io.ReadCloser, error) + AddValue(string, any) + AddFile(string, string, any) } -// NewClientStd 函数创建默认http客户端实现,参数为默认选项。 -func NewClientStd(options ...interface{}) Client { +// NewClient 函数创建默认http客户端实现,参数为默认选项。 +func NewClient(options ...any) Client { return &clientStd{ - Context: context.Background(), Client: &http.Client{ Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, @@ -58,16 +69,17 @@ func NewClientStd(options ...interface{}) Client { TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, }, + Timeout: DefaultClientTimeout, }, - Options: options, + Option: NewClientOption(context.Background(), options), } } // newDialContext 函数创建http客户端Dial函数,如果是内部请求Host,从环境上下文获取到Server处理连接。 func newDialContext() func(ctx context.Context, network, addr string) (net.Conn, error) { fn := (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, + Timeout: DefaultClientDialTimeout, + KeepAlive: DefaultClientDialKeepAlive, }).DialContext return func(ctx context.Context, network, addr string) (net.Conn, error) { if network == "tcp" && addr == DefaultClientInternalHost { @@ -83,37 +95,137 @@ func newDialContext() func(ctx context.Context, network, addr string) (net.Conn, } } -// Mount 方法保存context.Context作为Client默认发起请求的context.Context +// Mount 方法保存context.Context作为Client默认发起请求的context.Context。 func (client *clientStd) Mount(ctx context.Context) { - client.Context = ctx + client.Option.Context = ctx +} + +// WithClient 方法给客户端追加新的选项,返回客户端深拷贝。 +/* + Timeout + http.CookieJar + *http.Transport +*/ +func (client *clientStd) WithClient(options ...any) Client { + c := client.Client + if canCopyClient(options) { + c = &http.Client{} + *c = *client.Client + } + + for i := range options { + switch o := options[i].(type) { + case *http.Client: + c = o + case *http.Transport: + tp, ok := c.Transport.(*http.Transport) + if ok { + SetAnyDefault(tp, o) + } else { + c.Transport = o + } + case http.RoundTripper: + c.Transport = o + case http.CookieJar: + c.Jar = o + case time.Duration: + c.Timeout = o + } + } + + return &clientStd{ + Client: c, + Option: client.Option.clone().appendOptions(client.Option.Context, options), + } +} + +func canCopyClient(options []any) bool { + for i := range options { + switch options[i].(type) { + case *http.Client, *http.Transport, http.RoundTripper, http.CookieJar, time.Duration: + return true + } + } + return false +} + +// GetClient 方法返回*http.Client对象,用于修改属性。 +func (client *clientStd) GetClient() *http.Client { + return client.Client } // NewRequest 方法发送http请求。 -func (client *clientStd) NewRequest(ctx context.Context, method string, path string, options ...interface{}) error { - if ctx == nil { - ctx = client.Context +func (client *clientStd) NewRequest(ctx context.Context, method string, path string, options ...any) error { + option := client.Option.clone().appendOptions(ctx, options) + path = initRequestPath(option.Context, path) + if option.Trace != nil { + option.Context = NewClientTraceWithContext(option.Context, option.Trace) + } + + ctx = option.Context + if option.Retrys == nil && option.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(option.Context, option.Timeout) + defer cancel() } - req, err := http.NewRequestWithContext(ctx, method, initRequestPath(ctx, path), nil) + req, err := http.NewRequestWithContext(ctx, method, path, option.Body) if err != nil { return err } - ro, wo := initRequestOptions(req, append(client.Options, options...)) - for i := range ro { - ro[i](req) + option.apply(req) + + resp, err := client.dotry(req, option) + if resp != nil && resp.Body != nil { + defer resp.Body.Close() } + return option.release(req, resp, err) +} - resp, err := client.Client.Do(req) - if err != nil { - return err +func (client *clientStd) dotry(req *http.Request, option *ClientOption) (*http.Response, error) { + if option.Retrys == nil { + return client.Client.Do(req) } - for i := range wo { - err = wo[i](resp) - if err != nil { - return err + + attempts := make([]int, len(option.Retrys)) + for { + r := req + // retry set timeout + if option.Timeout > 0 { + ctx, cancel := context.WithTimeout(option.Context, option.Timeout) + defer cancel() + r = req.WithContext(ctx) + } + + resp, err := client.Client.Do(r) + if err == nil && resp.StatusCode < StatusTooManyRequests && resp.StatusCode != StatusUnauthorized { + return resp, err + } + + // If body has been sent + if resp != nil && req.Body != nil { + if req.GetBody == nil { + return resp, err + } + body, err2 := req.GetBody() + if err2 != nil { + return resp, err + } + req.Body = body + } + + notry := true + for i, retry := range option.Retrys { + if attempts[i] < retry.Max && retry.Condition(attempts[i], resp, err) { + attempts[i]++ + notry = false + break + } + } + if notry { + return resp, err } } - return nil } // initRequestPath 函数初始化请求url,如果Host为空设置默认或内部Host,如果请求协议为空设置为http。 @@ -137,380 +249,283 @@ func initRequestPath(ctx context.Context, path string) string { return u.String() } -// initRequestOptions 函数初始化请求选项,并返回全部ClientRequestOption和ClientResponseOption。 -func initRequestOptions(r *http.Request, options []interface{}) (ro []ClientRequestOption, wo []ClientResponseOption) { - for i := range options { - switch option := options[i].(type) { - case io.ReadCloser: - r.Body = option - r.GetBody = initGetBody(r.Body) - case io.Reader: - r.Body = ioutil.NopCloser(option) - r.GetBody = initGetBody(r.Body) - case string: - ro = append(ro, NewClientBodyString(option)) - case []byte: - ro = append(ro, NewClientBodyString(string(option))) - case http.Header: - headerCopy(r.Header, option) - case *http.Cookie: - r.AddCookie(option) - case url.Values: - v, err := url.ParseQuery(r.URL.RawQuery) - if err == nil { - headerCopy(v, option) - r.URL.RawQuery = v.Encode() - } - case ClientRequestOption: - ro = append(ro, option) - case func(*http.Request): - ro = append(ro, option) - case ClientResponseOption: - wo = append(wo, option) - case func(*http.Response) error: - wo = append(wo, option) - default: - ro = append(ro, NewClientBody(option)) - } - } - return +type bodyDecoder struct { + Reader io.ReadCloser + Values map[string]any + Data any + Type string + Encoder func(io.Writer, any) } -// WithClient 方法给客户端追加新的选项,返回客户端深拷贝。 -func (client *clientStd) WithClient(options ...interface{}) Client { - options = append(client.Options, options...) - return &clientStd{ - Context: client.Context, - Client: client.Client, - Options: options, - } +// NewClientBodyJSON 函数创建一个json编码器。 +func NewClientBodyJSON(data any) ClientBody { + return NewClientBodyDecoder(MimeApplicationJSON, data, func(w io.Writer, data any) { + json.NewEncoder(w).Encode(data) + }) } -// GetClient 方法返回*http.Client对象,用于修改属性。 -func (client *clientStd) GetClient() *http.Client { - return client.Client +// NewClientBodyXML 函数创建一个xml编码器。 +func NewClientBodyXML(data any) ClientBody { + return NewClientBodyDecoder(MimeApplicationXML, data, func(w io.Writer, data any) { + xml.NewEncoder(w).Encode(data) + }) } -// NewClientHost 数创建请求选项修改Host。 -func NewClientHost(host string) ClientRequestOption { - return func(r *http.Request) { - r.Host = host - r.Header.Set("Host", host) - } +// NewClientBodyProtobuf 函数创建一个protobuf编码器。 +func NewClientBodyProtobuf(data any) ClientBody { + return NewClientBodyDecoder(MimeApplicationProtobuf, data, func(w io.Writer, data any) { + NewProtobufEncoder(w).Encode(data) + }) } -// NewClientQuery 函数创建请求选项追加请求参数。 -func NewClientQuery(key, val string) ClientRequestOption { - return func(r *http.Request) { - v, err := url.ParseQuery(r.URL.RawQuery) - if err == nil { - v.Add(key, val) - r.URL.RawQuery = v.Encode() - } +// The NewClientBodyDecoder function creates a ClientBody encoder, +// which needs to specify contenttype and encoder. +// +// NewClientBodyDecoder 函数创建一个ClientBody编码器,需要指定contenttype和encoder。 +func NewClientBodyDecoder(contenttype string, data any, encoder func(io.Writer, any)) ClientBody { + if data == nil { + data = make(map[string]any) + } + vals, _ := data.(map[string]any) + body := &bodyDecoder{ + Data: data, + Values: vals, + Type: contenttype, + Encoder: encoder, } + return body } -// NewClientQuerys 函数创建请求选项加请求参数。 -func NewClientQuerys(querys url.Values) ClientRequestOption { - return func(r *http.Request) { - v, err := url.ParseQuery(r.URL.RawQuery) - if err == nil { - headerCopy(v, querys) - r.URL.RawQuery = v.Encode() - } +func (body *bodyDecoder) Read(p []byte) (int, error) { + if body.Reader == nil { + rc, wc := io.Pipe() + body.Reader = rc + go func() { + body.Encoder(wc, body.Data) + wc.Close() + }() } + return body.Reader.Read(p) } -// NewClientHeader 函数创建请求选项追加Header。 -func NewClientHeader(key, val string) ClientRequestOption { - return func(r *http.Request) { - r.Header.Add(key, val) +func (body *bodyDecoder) Close() error { + if body.Reader != nil { + return body.Reader.Close() } + return nil } -// NewClientHeaders 函数创建请求选项追加Header。 -func NewClientHeaders(headers http.Header) ClientRequestOption { - return func(r *http.Request) { - headerCopy(r.Header, headers) - } +func (body *bodyDecoder) GetContentType() string { + return body.Type } -// NewClientCookie 函数创建请求选项追加请求Cookie -func NewClientCookie(key, val string) ClientRequestOption { - return func(r *http.Request) { - r.AddCookie(&http.Cookie{Name: key, Value: val}) - } +func (body *bodyDecoder) GetBody() (io.ReadCloser, error) { + return &bodyDecoder{ + Data: body.Data, + Values: body.Values, + Type: body.Type, + Encoder: body.Encoder, + }, nil } -// NewClientBasicAuth 函数创建请求选项追加BasicAuth用户信息。 -func NewClientBasicAuth(username, password string) ClientRequestOption { - return func(r *http.Request) { - r.SetBasicAuth(username, password) +func (body *bodyDecoder) AddValue(key string, val any) { + if body.Values != nil { + body.Values[key] = val + } else { + SetAnyByPath(body.Data, key, val) } } -// NewClientBody 方法追加请求Body字符串。 -func NewClientBody(data interface{}) ClientRequestOption { - return func(r *http.Request) { - var contenttype string - switch reflect.Indirect(reflect.ValueOf(data)).Kind() { - case reflect.Struct: - contenttype = r.Header.Get(HeaderContentType) - if contenttype == "" { - contenttype = DefaultClientBodyContextType - } - case reflect.Slice, reflect.Map: - contenttype = MimeApplicationJSON - default: - return - } - switch contenttype { - case MimeApplicationJSON, MimeApplicationJSONCharsetUtf8: - r.Body = &bodyEncoder{data: data, contenttype: MimeApplicationJSON} - case MimeApplicationXML, MimeApplicationXMLCharsetUtf8: - r.Body = &bodyEncoder{data: data, contenttype: MimeApplicationXML} - case MimeApplicationProtobuf: - r.Body = &bodyEncoder{data: data, contenttype: MimeApplicationProtobuf} - default: - return - } - r.GetBody = initGetBody(r.Body) - } +func (body *bodyDecoder) AddFile(string, string, any) {} + +type bodyForm struct { + Reader io.ReadCloser + Values url.Values + Files map[string][]fileContent + Boundary string + NoClone bool } -// NewClientBodyString 方法追加请求Body字符串。 -func NewClientBodyString(str string) ClientRequestOption { - return func(r *http.Request) { - if r.Body == nil { - r.Body = &bodyBuffer{} - r.GetBody = initGetBody(r.Body) - } - body, ok := r.Body.(*bodyBuffer) - if ok { - body.WriteString(str) - r.ContentLength = int64(body.Len()) - } - } +type fileContent struct { + Name string + Body []byte + File string + Reader io.Reader } -// NewClientBodyJSON 方法追加请求json值或json对象。 -func NewClientBodyJSON(data interface{}) ClientRequestOption { - return func(r *http.Request) { - r.ContentLength = -1 - r.Header.Add(HeaderContentType, "application/json") - body, ok := data.(map[string]interface{}) - if ok { - r.Body = &bodyJSON{values: body} +// NewClientBodyForm 函数创建ApplicationForm或MultipartForm请求body。 +// +// AddFile方法允许data类型为[]byte io.Reader;如果类型为string则加载这个本地文件。 +// +// 如果使用AddFile方法添加文件ContentType为MultipartForm。 +func NewClientBodyForm(data url.Values) ClientBody { + return &bodyForm{Values: data, Boundary: GetStringRandom(30)} +} + +func (body *bodyForm) Read(p []byte) (n int, err error) { + if body.Reader == nil { + if body.Files == nil { + body.Reader = io.NopCloser(strings.NewReader(body.Values.Encode())) } else { - r.Body = &bodyJSON{data: data} + rc, wc := io.Pipe() + body.Reader = rc + body.encode(wc) } - r.GetBody = initGetBody(r.Body) } + return body.Reader.Read(p) } -// NewClientBodyJSONValue 方法追加请求json值。 -func NewClientBodyJSONValue(key string, val interface{}) ClientRequestOption { - return func(r *http.Request) { - if r.Body == nil && r.Header.Get(HeaderContentType) == "" { - r.ContentLength = -1 - r.Header.Add(HeaderContentType, "application/json") - r.Body = &bodyJSON{values: make(map[string]interface{})} - r.GetBody = initGetBody(r.Body) +func (body *bodyForm) encode(wc io.WriteCloser) { + w := multipart.NewWriter(wc) + w.SetBoundary(body.Boundary) + go func() { + for key, vals := range body.Values { + for _, val := range vals { + w.WriteField(key, val) + } } - body, ok := r.Body.(*bodyJSON) - if ok { - body.values[key] = val + for key, vals := range body.Files { + for _, val := range vals { + part, _ := w.CreateFormFile(key, val.Name) + switch { + case val.Body != nil: + part.Write(val.Body) + case val.Reader != nil: + io.Copy(part, val.Reader) + c, ok := val.Reader.(io.Closer) + if ok { + c.Close() + } + case val.File != "": + file, err := os.Open(val.File) + if err == nil { + io.Copy(part, file) + file.Close() + } + } + } } - } + w.Close() + wc.Close() + }() } -// NewClientBodyFormValue 方法追加请求Form值。 -func NewClientBodyFormValue(key string, val string) ClientRequestOption { - return func(r *http.Request) { - initBodyForm(r) - body, ok := r.Body.(*bodyForm) - if ok { - body.Values[key] = append(body.Values[key], val) - } +func (body *bodyForm) Close() error { + if body.Reader != nil { + return body.Reader.Close() } + return nil } -// NewClientBodyFormValues 方法追加请求Form值。 -func NewClientBodyFormValues(data map[string]string) ClientRequestOption { - return func(r *http.Request) { - initBodyForm(r) - body, ok := r.Body.(*bodyForm) - if ok { - for key, val := range data { - body.Values[key] = append(body.Values[key], val) - } - } +func (body *bodyForm) GetContentType() string { + if body.Files == nil { + return MimeApplicationForm } + + return "multipart/form-data; boundary=" + body.Boundary } -// NewClientBodyFormFile 方法给form添加文件内容,文件类型可以为[]byte string io.ReadCloser io.Reader。 -func NewClientBodyFormFile(key, name string, val interface{}) ClientRequestOption { - return func(r *http.Request) { - initBodyForm(r) - body, ok := r.Body.(*bodyForm) - if ok { - var content fileContent - switch body := val.(type) { - case []byte: - content.Body = body - case string: - content.Body = []byte(body) - case io.ReadCloser: - content.Reader = body - case io.Reader: - content.Reader = ioutil.NopCloser(body) - default: - return - } - content.Name = name - body.Files[key] = append(body.Files[key], content) - } +func (body *bodyForm) GetBody() (io.ReadCloser, error) { + if body.NoClone { + return nil, ErrClientBodyFormNotGetBody } + return &bodyForm{ + Values: body.Values, + Files: body.Files, + Boundary: body.Boundary, + }, nil } -// NewClientBodyFormLocalFile 方法给form添加本地文件内容。 -func NewClientBodyFormLocalFile(key, name, path string) ClientRequestOption { - return func(r *http.Request) { - initBodyForm(r) - body, ok := r.Body.(*bodyForm) - if ok { - if name == "" { - name = filepath.Base(path) - } - body.Files[key] = append(body.Files[key], fileContent{ - Name: name, - File: path, - }) - } +func (body *bodyForm) AddValue(key string, val any) { + if body.Values == nil { + body.Values = make(url.Values) } + body.Values.Add(key, GetStringByAny(val)) } -// initBodyForm -func initBodyForm(r *http.Request) { - if r.Body == nil && r.Header.Get(HeaderContentType) == "" { - var buf [30]byte - io.ReadFull(rand.Reader, buf[:]) - boundary := fmt.Sprintf("%x", buf[:]) - r.ContentLength = -1 - r.Header.Add(HeaderContentType, "multipart/form-data; boundary="+boundary) - r.Body = &bodyForm{ - Boundary: boundary, - Values: make(map[string][]string), - Files: make(map[string][]fileContent), - } - r.GetBody = initGetBody(r.Body) +func (body *bodyForm) AddFile(key string, name string, data any) { + if body.Files == nil { + body.Files = make(map[string][]fileContent) } -} -// initGetBody 函数创建http.GetBody函数,如果body实现Clone() io.ReadCloser方法,在调用GetBody时使用。 -func initGetBody(data interface{}) func() (io.ReadCloser, error) { - return func() (io.ReadCloser, error) { - body, ok := data.(interface{ Clone() io.ReadCloser }) - if ok { - return body.Clone(), nil + content := fileContent{Name: name} + switch b := data.(type) { + case []byte: + content.Body = b + case string: + if name == "" { + content.Name = filepath.Base(b) } - return http.NoBody, nil + content.File = b + case io.Reader: + body.NoClone = true + content.Reader = b + default: + return } + body.Files[key] = append(body.Files[key], content) } -// NewClientTimeout 函数创建请求选项设置请求超时时间。 -func NewClientTimeout(timeout time.Duration) ClientRequestOption { - return func(r *http.Request) { - ctx, _ := context.WithTimeout(r.Context(), timeout) - *r = *r.WithContext(ctx) +// NewClientCheckStatus 方法创建响应选项检查响应状态码。 +func NewClientCheckStatus(status ...int) func(*http.Response) error { + return func(w *http.Response) error { + for i := range status { + if status[i] == w.StatusCode { + return nil + } + } + + return fmt.Errorf(ErrFormatClintCheckStatusError, w.StatusCode, status) } } -// NewClientTrace 函数创建请求选项在请求上下文保存ClientTrace对象和httptrace.ClientTrace,实现http客户端追踪。 -func NewClientTrace() ClientRequestOption { - return func(r *http.Request) { - trace := &ClientTrace{ - HTTPStart: time.Now(), - WroteHeaders: make(http.Header), +// NewClienProxyWriter 函数将客户端响应写入另外Writer, +// +// 如果Writer实现http.ResponseWriter接口会写入状态码和Header。 +func NewClienProxyWriter(writer io.Writer) func(*http.Response) error { + return func(w *http.Response) error { + wr, ok := writer.(http.ResponseWriter) + if ok { + wr.WriteHeader(w.StatusCode) + h := w.Header.Clone() + for _, key := range DefaultClinetHopHeaders { + h.Del(key) + } + for key, vals := range h { + for _, val := range vals { + wr.Header().Add(key, val) + } + } } - ctx := context.WithValue(r.Context(), ContextKeyClientTrace, trace) - ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ - DNSStart: func(info httptrace.DNSStartInfo) { trace.DNSStart = time.Now(); trace.DNSHost = info.Host }, - DNSDone: func(info httptrace.DNSDoneInfo) { trace.DNSDone = time.Now(); trace.DNSAddrs = info.Addrs }, - ConnectStart: func(network, addr string) { - trace.ConnectStart = time.Now() - trace.ConnectNetwork = network - trace.ConnectAddress = addr - }, - ConnectDone: func(string, string, error) { trace.ConnectDone = time.Now() }, - GetConn: func(hostPort string) { trace.GetConn = time.Now(); trace.GetConnHostPort = hostPort }, - GotConn: func(httptrace.GotConnInfo) { trace.GotConn = time.Now() }, - GotFirstResponseByte: func() { trace.GotFirstResponseByte = time.Now() }, - TLSHandshakeStart: func() { trace.TLSHandshakeStart = time.Now() }, - TLSHandshakeDone: func(state tls.ConnectionState, _ error) { - trace.TLSHandshakeDone = time.Now() - trace.TLSHandshakeState = &state - }, - WroteHeaderField: func(key string, value []string) { trace.WroteHeaders[key] = value }, - }) - *r = *r.WithContext(ctx) - } -} - -// ClientTrace 定义http客户端请求追踪记录的数据 -type ClientTrace struct { - HTTPStart time.Time `json:"http-start" xml:"http-start"` - HTTPDone time.Time `json:"http-done" xml:"http-done"` - HTTPDuration time.Duration `json:"http-duration" xml:"http-duration"` - DNSStart time.Time `json:"dns-start,omitempty" xml:"dns-start,omitempty"` - DNSDone time.Time `json:"dns-done,omitempty" xml:"dns-done,omitempty"` - DNSDuration time.Duration `json:"dns-duration,omitempty" xml:"dns-duration,omitempty"` - DNSHost string `json:"dns-host,omitempty" xml:"dns-host,omitempty"` - DNSAddrs []net.IPAddr `json:"dns-addrs,omitempty" xml:"dns-addrs,omitempty"` - ConnectStart time.Time `json:"connect-start,omitempty" xml:"connect-start,omitempty"` - ConnectDone time.Time `json:"connect-done,omitempty" xml:"connect-done,omitempty"` - ConnectDuration time.Duration `json:"connect-duration,omitempty" xml:"connect-duration,omitempty"` - ConnectNetwork string `json:"connect-network,omitempty" xml:"connect-network,omitempty"` - ConnectAddress string `json:"connect-address,omitempty" xml:"connect-address,omitempty"` - GetConn time.Time `json:"get-conn" xml:"get-conn"` - GetConnHostPort string `json:"get-conn-host-port" xml:"get-conn-host-port"` - GotConn time.Time `json:"got-conn" xml:"got-conn"` - GotFirstResponseByte time.Time `json:"got-first-response-byte" xml:"got-first-response-byte"` - TLSHandshakeStart time.Time `json:"tls-handshake-start,omitempty" xml:"tls-handshake-start,omitempty"` - TLSHandshakeDone time.Time `json:"tls-handshake-done,omitempty" xml:"tls-handshake-done,omitempty"` - TLSHandshakeDuration time.Duration `json:"tls-handshake-duration,omitempty" xml:"tls-handshake-duration,omitempty"` - TLSHandshakeState *tls.ConnectionState `json:"tls-handshake-state,omitempty" xml:"tls-handshake-state,omitempty"` - TLSHandshakeIssuer string `json:"tls-handshake-issuer,omitempty" xml:"tls-handshake-issuer,omitempty"` - TLSHandshakeSubject string `json:"tls-handshake-subject,omitempty" xml:"tls-handshake-subject,omitempty"` - TLSHandshakeNotBefore time.Time `json:"tls-handshake-not-before,omitempty" xml:"tls-handshake-not-before,omitempty"` - TLSHandshakeNotAfter time.Time `json:"tls-handshake-not-after,omitempty" xml:"tls-handshake-not-after,omitempty"` - TLSHandshakeDigest string `json:"tls-handshake-digest,omitempty" xml:"tls-handshake-digest,omitempty"` - WroteHeaders http.Header `json:"-" xml:"-" description:"http write header"` + + _, err := io.Copy(writer, w.Body) + return err + } } // NewClientParse 方法创建响应选项解析body数据。 -func NewClientParse(data interface{}) ClientResponseOption { +func NewClientParse(data any) func(*http.Response) error { return func(w *http.Response) error { return clientParseIn(w, 0, 0xffffffff, data) } } // NewClientParseIf 方法创建响应选项,在指定状态码时解析body数据。 -func NewClientParseIf(status int, data interface{}) ClientResponseOption { +func NewClientParseIf(status int, data any) func(*http.Response) error { return func(w *http.Response) error { return clientParseIn(w, status, status, data) } } // NewClientParseIn 方法创建响应选项,在指定状态码范围时解析body数据。 -func NewClientParseIn(star, end int, data interface{}) ClientResponseOption { +func NewClientParseIn(star, end int, data any) func(*http.Response) error { return func(w *http.Response) error { return clientParseIn(w, star, end, data) } } // NewClientParseErr 方法创建响应选项,在默认范围时解析body中的Error字段返回。 -func NewClientParseErr() ClientResponseOption { +func NewClientParseErr() func(*http.Response) error { return func(w *http.Response) error { var data struct { Status int `json:"status" protobuf:"6,name=status" xml:"status" yaml:"status"` @@ -522,14 +537,23 @@ func NewClientParseErr() ClientResponseOption { return err } if data.Error != "" { - return fmt.Errorf(data.Error) + return errors.New(data.Error) } return nil } } -func clientParseIn(w *http.Response, star, end int, data interface{}) error { - if w.StatusCode < star || w.StatusCode > end { +func clientParseIn(w *http.Response, star, end int, data any) error { + if w.StatusCode < star || w.StatusCode > end || w.Body == nil { + return nil + } + body, ok := data.(*string) + if ok { + data, err := io.ReadAll(w.Body) + if err != nil { + return err + } + *body = string(data) return nil } mime := w.Header.Get(HeaderContentType) @@ -545,262 +569,20 @@ func clientParseIn(w *http.Response, star, end int, data interface{}) error { case MimeApplicationProtobuf: return NewProtobufDecoder(w.Body).Decode(data) } - return fmt.Errorf("eudore client parse not suppert Content-Type: %s", mime) -} - -// NewClientCheckStatus 方法创建响应选项检查响应状态码。 -func NewClientCheckStatus(status ...int) ClientResponseOption { - return func(w *http.Response) error { - for i := range status { - if status[i] == w.StatusCode { - return nil - } - } - - err := fmt.Errorf("check status is %d not in %v", w.StatusCode, status) - r := w.Request - NewLoggerWithContext(w.Request.Context()).WithFields( - []string{"method", "host", "path", "query", "status-code", "status"}, - []interface{}{r.Method, r.Host, r.URL.Path, r.URL.RawQuery, w.StatusCode, w.Status}, - ).Error(err) - return err - } + return fmt.Errorf(ErrFormatClintParseBodyError, mime) } // NewClientCheckBody 方法创建响应选项检查响应body是否包含指定字符串。 -func NewClientCheckBody(str string) ClientResponseOption { +func NewClientCheckBody(str string) func(*http.Response) error { return func(w *http.Response) error { - body, err := ioutil.ReadAll(w.Body) + body, err := io.ReadAll(w.Body) if err != nil { return err } - w.Body = ioutil.NopCloser(bytes.NewReader(body)) - if strings.Index(string(body), str) == -1 { - err := fmt.Errorf("check body not have string '%s'", str) - r := w.Request - NewLoggerWithContext(w.Request.Context()).WithFields( - []string{"method", "host", "path", "query", "status-code", "status"}, - []interface{}{r.Method, r.Host, r.URL.Path, r.URL.RawQuery, w.StatusCode, w.Status}, - ).Error(err) - return err + w.Body = io.NopCloser(bytes.NewReader(body)) + if !strings.Contains(string(body), str) { + return fmt.Errorf("check body not have string '%s'"+string(body[:20]), str) } return nil } } - -// NewClientDumpHead 方法创建响应选项从环境上下文获取Logger输出请求基本信息和Trace信息。 -func NewClientDumpHead() ClientResponseOption { - return newClientDumpWithBody(false) -} - -// NewClientDumpBody 方法创建响应选项从环境上下文获取Logger输出响应head和body内容。 -func NewClientDumpBody() ClientResponseOption { - return newClientDumpWithBody(true) -} - -func newClientDumpWithBody(hasbody bool) ClientResponseOption { - return func(w *http.Response) error { - log := NewLoggerWithContext(w.Request.Context()) - if log == DefaultLoggerNull || log.GetLevel() > LoggerDebug { - return nil - } - r := w.Request - log = log.WithFields( - []string{"proto", "method", "host", "path", "query", "status-code", "status", "request-headers", "response-headers"}, - []interface{}{w.Proto, r.Method, r.Host, r.URL.Path, r.URL.RawQuery, w.StatusCode, w.Status, r.Header, w.Header}, - ) - - for _, name := range []string{HeaderXRequestID, HeaderXTraceID} { - id := w.Header.Get(name) - if id != "" { - log.WithField(strings.ToLower(name), id) - break - } - } - - trace, ok := w.Request.Context().Value(ContextKeyClientTrace).(*ClientTrace) - if ok { - trace.HTTPDone = time.Now() - trace.HTTPDuration = trace.HTTPDone.Sub(trace.HTTPStart) - trace.DNSDuration = trace.DNSDone.Sub(trace.DNSStart) - trace.ConnectDuration = trace.ConnectDone.Sub(trace.ConnectStart) - trace.TLSHandshakeDuration = trace.TLSHandshakeDone.Sub(trace.TLSHandshakeStart) - if trace.TLSHandshakeState != nil { - cert := trace.TLSHandshakeState.PeerCertificates[0] - h := sha256.New() - h.Write(cert.Raw) - trace.TLSHandshakeIssuer = cert.Issuer.String() - trace.TLSHandshakeSubject = cert.Subject.String() - trace.TLSHandshakeNotBefore = cert.NotBefore - trace.TLSHandshakeNotAfter = cert.NotAfter - trace.TLSHandshakeDigest = hex.EncodeToString(h.Sum(nil)) - trace.TLSHandshakeState = nil - } - log = log.WithFields([]string{"wrote-headers", "trace"}, []interface{}{trace.WroteHeaders, trace}) - } - - if hasbody { - body, err := ioutil.ReadAll(w.Body) - if err != nil { - return err - } - w.Body = ioutil.NopCloser(bytes.NewReader(body)) - log.WithField("body", string(body)) - } - log.Debug() - return nil - } -} - -func headerCopy(dst, src map[string][]string) map[string][]string { - for key, vals := range src { - dst[key] = append(dst[key], vals...) - } - return dst -} - -type bodyBuffer struct { - bytes.Buffer -} - -func (body *bodyBuffer) Clone() io.ReadCloser { - buf := &bodyBuffer{} - buf.Write(body.Bytes()) - return buf -} - -func (body *bodyBuffer) Close() error { - return nil -} - -type bodyJSON struct { - reader *io.PipeReader - writer *io.PipeWriter - data interface{} - values map[string]interface{} -} - -func (body *bodyJSON) Clone() io.ReadCloser { - return &bodyJSON{ - data: body.data, - values: body.values, - } -} - -func (body *bodyJSON) Read(p []byte) (n int, err error) { - if body.reader == nil { - body.reader, body.writer = io.Pipe() - go func() { - if body.data != nil { - json.NewEncoder(body.writer).Encode(body.data) - } else { - json.NewEncoder(body.writer).Encode(body.values) - } - body.writer.Close() - }() - } - return body.reader.Read(p) -} - -func (body *bodyJSON) Close() error { - return body.reader.Close() -} - -type bodyForm struct { - reader *io.PipeReader - writer *io.PipeWriter - Boundary string - Values map[string][]string - Files map[string][]fileContent -} - -type fileContent struct { - Name string - Body []byte - File string - Reader io.ReadCloser -} - -func (body *bodyForm) Clone() io.ReadCloser { - return &bodyForm{ - Boundary: body.Boundary, - Values: body.Values, - Files: body.Files, - } -} - -func (body *bodyForm) Read(p []byte) (n int, err error) { - if body.reader == nil { - body.reader, body.writer = io.Pipe() - w := multipart.NewWriter(body.writer) - w.SetBoundary(body.Boundary) - go func() { - for key, vals := range body.Values { - for _, val := range vals { - w.WriteField(key, val) - } - } - for key, vals := range body.Files { - for _, val := range vals { - part, _ := w.CreateFormFile(key, val.Name) - switch { - case val.Body != nil: - part.Write(val.Body) - case val.Reader != nil: - io.Copy(part, val.Reader) - val.Reader.Close() - case val.File != "": - file, err := os.Open(val.File) - if err == nil { - io.Copy(part, file) - file.Close() - } - } - } - } - w.Close() - body.writer.Close() - }() - } - return body.reader.Read(p) -} - -func (body *bodyForm) Close() error { - return body.reader.Close() -} - -type bodyEncoder struct { - reader *io.PipeReader - writer *io.PipeWriter - contenttype string - data interface{} -} - -func (body *bodyEncoder) Clone() io.ReadCloser { - return &bodyEncoder{ - contenttype: body.contenttype, - data: body.data, - } -} - -func (body *bodyEncoder) Read(p []byte) (n int, err error) { - if body.reader == nil { - body.reader, body.writer = io.Pipe() - go func() { - switch body.contenttype { - case MimeApplicationJSON: - json.NewEncoder(body.writer).Encode(body.data) - case MimeApplicationXML: - json.NewEncoder(body.writer).Encode(body.data) - case MimeApplicationProtobuf: - NewProtobufEncoder(body.writer).Encode(body.data) - } - body.writer.Close() - }() - } - return body.reader.Read(p) -} - -func (body *bodyEncoder) Close() error { - return body.reader.Close() -} diff --git a/clientoption.go b/clientoption.go new file mode 100644 index 0000000..9ef6d76 --- /dev/null +++ b/clientoption.go @@ -0,0 +1,573 @@ +package eudore + +import ( + "bytes" + "context" + "crypto/md5" + "crypto/sha256" + "crypto/tls" + "encoding/base64" + "encoding/hex" + "fmt" + "hash" + "io" + "net" + "net/http" + "net/http/httptrace" + "net/url" + "strings" + "sync" + "time" + "unsafe" +) + +// ClientOption 定义创建客户端请求时额外选项。 +type ClientOption struct { + Context context.Context + Timeout time.Duration + Body io.Reader + ClientBody ClientBody + Values url.Values + Header http.Header + Headers []string + Cookies []string + RequestHooks []func(*http.Request) + ResponseHooks []func(*http.Response) error + Retrys []ClientRetry + // Trace saves ClientTrace data, and enables httptrace when it is not empty. + // Trace 保存ClientTrace数据,非空时启用httptrace。 + Trace *ClientTrace +} + +// ClientRetry 定义客户端请求重试行为。 +type ClientRetry struct { + Max int + Condition func(int, *http.Response, error) bool +} + +// ClientTrace 定义http客户端请求追踪记录的数据。 +type ClientTrace struct { + sync.Mutex `alias:"mutex" json:"-" xml:"-" yaml:"-"` + HTTPStart time.Time `alias:"http-start" json:"http-start" xml:"http-start" yaml:"http-start"` + HTTPDone time.Time `alias:"http-done" json:"http-done" xml:"http-done" yaml:"http-done"` + HTTPDuration time.Duration `alias:"http-duration" json:"http-duration" xml:"http-duration" yaml:"http-duration"` + DNSStart time.Time `alias:"dns-start,omitempty" json:"dns-start,omitempty" xml:"dns-start,omitempty" yaml:"dns-start,omitempty"` + DNSDone time.Time `alias:"dns-done,omitempty" json:"dns-done,omitempty" xml:"dns-done,omitempty" yaml:"dns-done,omitempty"` + DNSDuration time.Duration `alias:"dns-duration,omitempty" json:"dns-duration,omitempty" xml:"dns-duration,omitempty" yaml:"dns-duration,omitempty"` + DNSHost string `alias:"dns-host,omitempty" json:"dns-host,omitempty" xml:"dns-host,omitempty" yaml:"dns-host,omitempty"` + DNSAddrs []net.IPAddr `alias:"dns-addrs,omitempty" json:"dns-addrs,omitempty" xml:"dns-addrs,omitempty" yaml:"dns-addrs,omitempty"` + Connect []ClientTraceConnect `alias:"connect" json:"connect" xml:"connect" yaml:"connect"` + GetConn time.Time `alias:"get-conn" json:"get-conn" xml:"get-conn" yaml:"get-conn"` + GetConnHostPort string `alias:"get-conn-host-port" json:"get-conn-host-port" xml:"get-conn-host-port" yaml:"get-conn-host-port"` + GotConn time.Time `alias:"got-conn" json:"got-conn" xml:"got-conn" yaml:"got-conn"` + GotFirstResponseByte time.Time `alias:"got-first-response-byte" json:"got-first-response-byte" xml:"got-first-response-byte" yaml:"got-first-response-byte"` + TLSHandshakeStart time.Time `alias:"tls-handshake-start,omitempty" json:"tls-handshake-start,omitempty" xml:"tls-handshake-start,omitempty" yaml:"tls-handshake-start,omitempty"` + TLSHandshakeDone time.Time `alias:"tls-handshake-done,omitempty" json:"tls-handshake-done,omitempty" xml:"tls-handshake-done,omitempty" yaml:"tls-handshake-done,omitempty"` + TLSHandshakeDuration time.Duration `alias:"tls-handshake-duration,omitempty" json:"tls-handshake-duration,omitempty" xml:"tls-handshake-duration,omitempty" yaml:"tls-handshake-duration,omitempty"` + TLSHandshakeError error `alias:"tls-handshake-error,omitempty" json:"tls-handshake-error,omitempty" xml:"tls-handshake-error,omitempty" yaml:"tls-handshake-error,omitempty"` + TLSHandshakeIssuer string `alias:"tls-handshake-issuer,omitempty" json:"tls-handshake-issuer,omitempty" xml:"tls-handshake-issuer,omitempty" yaml:"tls-handshake-issuer,omitempty"` + TLSHandshakeSubject string `alias:"tls-handshake-subject,omitempty" json:"tls-handshake-subject,omitempty" xml:"tls-handshake-subject,omitempty" yaml:"tls-handshake-subject,omitempty"` + TLSHandshakeNotBefore time.Time `alias:"tls-handshake-not-before,omitempty" json:"tls-handshake-not-before,omitempty" xml:"tls-handshake-not-before,omitempty" yaml:"tls-handshake-not-before,omitempty"` + TLSHandshakeNotAfter time.Time `alias:"tls-handshake-not-after,omitempty" json:"tls-handshake-not-after,omitempty" xml:"tls-handshake-not-after,omitempty" yaml:"tls-handshake-not-after,omitempty"` + TLSHandshakeDigest string `alias:"tls-handshake-digest,omitempty" json:"tls-handshake-digest,omitempty" xml:"tls-handshake-digest,omitempty" yaml:"tls-handshake-digest,omitempty"` + WroteHeaders http.Header `alias:"wrote-headers,omitempty" json:"wrote-headers,omitempty" xml:"wrote-headers,omitempty" yaml:"wrote-headers,omitempty"` +} + +// ClientTraceConnect 定义Trace连接信息,一个请求可能出现多连接。 +type ClientTraceConnect struct { + Network string `alias:"network" json:"network" xml:"network" yaml:"network"` + Address string `alias:"address" json:"address" xml:"address" yaml:"address"` + Start time.Time `alias:"start" json:"start" xml:"start" yaml:"start"` + Done time.Time `alias:"done,omitempty" json:"done,omitempty" xml:"done,omitempty" yaml:"done,omitempty"` + Duration time.Duration `alias:"duration,omitempty" json:"duration,omitempty" xml:"duration,omitempty" yaml:"duration,omitempty"` + Error error `alias:"error,omitempty" json:"error,omitempty" xml:"error,omitempty" yaml:"error,omitempty"` +} + +// NewClientOption 函数使用options创建ClientOption。 +func NewClientOption(ctx context.Context, options []any) *ClientOption { + co := &ClientOption{} + return co.appendOptions(ctx, options) +} + +func appendValues(dst, src map[string][]string) map[string][]string { + if src == nil { + return dst + } + if dst == nil { + dst = make(map[string][]string) + } + for key, vals := range src { + dst[key] = append(dst[key], vals...) + } + return dst +} + +func (co *ClientOption) clone() *ClientOption { + o := &ClientOption{} + *o = *co + if o.Values != nil { + o.Values = appendValues(make(url.Values, len(o.Values)), o.Values) + } + if o.Header != nil { + o.Header = o.Header.Clone() + } + return o +} + +//nolint:cyclop,gocyclo +func (co *ClientOption) appendOptions(ctx context.Context, options []any) *ClientOption { + if ctx != nil { + co.Context = ctx + } + for i := range options { + switch o := options[i].(type) { + case context.Context: + co.Context = o + case ClientBody: + co.ClientBody = o + co.Body = o + case io.Reader: + co.Body = o + case url.Values: + co.Values = appendValues(co.Values, o) + case http.Header: + co.Header = appendValues(co.Header, o) + case Cookie: + co.Cookies = clearCap(append(co.Cookies, o.String())) + case *http.Cookie: + co.Cookies = clearCap(append(co.Cookies, Cookie{Name: o.Name, Value: o.Value}.String())) + case time.Duration: + co.Timeout = o + case func(*http.Request): + co.RequestHooks = clearCap(append(co.RequestHooks, o)) + case func(*http.Response) error: + co.ResponseHooks = clearCap(append(co.ResponseHooks, o)) //nolint:bodyclose + case ClientRetry: + co.Retrys = clearCap(append(co.Retrys, o)) + case *ClientTrace: + co.Trace = o + case *ClientOption: + co.append(o) + } + } + + return co +} + +func (co *ClientOption) append(o *ClientOption) { + if o == nil { + return + } + co.Context = GetAnyDefault(o.Context, co.Context) + co.Timeout = GetAnyDefault(o.Timeout, co.Timeout) + co.Body = GetAnyDefault(o.Body, co.Body) + co.ClientBody = GetAnyDefault(o.ClientBody, co.ClientBody) + + co.Values = appendValues(co.Values, o.Values) + co.Header = appendValues(co.Header, o.Header) + co.Headers = clearCap(append(co.Headers, o.Headers...)) + co.Cookies = clearCap(append(co.Cookies, o.Cookies...)) + co.RequestHooks = clearCap(append(co.RequestHooks, o.RequestHooks...)) + co.ResponseHooks = clearCap(append(co.ResponseHooks, o.ResponseHooks...)) //nolint:bodyclose + co.Retrys = clearCap(append(co.Retrys, o.Retrys...)) + if co.Trace == nil && o.Trace != nil { + co.Trace = &ClientTrace{} + } +} + +func (co *ClientOption) apply(req *http.Request) { + if co.Values != nil { + v, err := url.ParseQuery(req.URL.RawQuery) + if err == nil { + v = appendValues(v, co.Values) + req.URL.RawQuery = v.Encode() + } + } + if co.Header != nil { + req.Header = appendValues(req.Header, co.Header) + } + if co.Headers != nil { + for i := 0; i < len(co.Headers); i += 2 { + req.Header.Set(co.Headers[i], co.Headers[i+1]) + } + } + if co.Cookies != nil { + s := strings.Join(co.Cookies, "; ") + if c := req.Header.Get(HeaderCookie); c != "" { + req.Header.Set(HeaderCookie, c+"; "+s) + } else { + req.Header.Set(HeaderCookie, s) + } + } + + if co.ClientBody != nil { + req.ContentLength = -1 + req.Header.Set(HeaderContentType, co.ClientBody.GetContentType()) + req.GetBody = co.ClientBody.GetBody + } + for _, hook := range co.RequestHooks { + hook(req) + } +} + +var clientLoggerRequestIDKeys = [...]string{ + HeaderXRequestID, strings.ToLower(HeaderXRequestID), + HeaderXTraceID, strings.ToLower(HeaderXTraceID), +} + +func (co *ClientOption) release(req *http.Request, resp *http.Response, err error) error { + trace := co.Trace + if trace != nil { + trace.HTTPDone = time.Now() + trace.HTTPDuration = trace.HTTPDone.Sub(trace.HTTPStart) + } + for _, hook := range co.ResponseHooks { + if err != nil { + break + } + err = hook(resp) + } + + level := LoggerDebug + if err != nil { + level = LoggerError + } + if level >= DefaultClinetLoggerLevel { + log := NewLoggerWithContext(co.Context) + if log != DefaultLoggerNull && level >= log.GetLevel() { + keys := []string{"method", "scheme", "host", "path", "query", "request-header"} + vals := []any{req.Method, req.URL.Scheme, req.Host, req.URL.Path, req.URL.RawQuery, req.Header} + if resp != nil { + keys = append(keys, "proto", "status", "status-code", "response-header") + vals = append(vals, resp.Proto, resp.StatusCode, resp.Status, resp.Header) + // append id + for i := 0; i < 4; i += 2 { + val := resp.Header.Get(clientLoggerRequestIDKeys[i]) + if val != "" { + log = log.WithField(clientLoggerRequestIDKeys[i+1], val) + } + } + } + if trace != nil { + keys = append(keys, "trace") + vals = append(vals, trace) + // lock trace to loggerformat + trace.Lock() + defer trace.Unlock() + } + + if err != nil { + log.WithFields(keys, vals).Error(err.Error()) + } else { + log.WithFields(keys, vals).Debug() + } + } + } + + return err +} + +// NewClientOptionBasicauth 函数设置请求Basic Auth权限。 +func NewClientOptionBasicauth(username, password string) *ClientOption { + auth := username + ":" + password + return NewClientOptionHeader(HeaderAuthorization, "Basic "+base64.StdEncoding.EncodeToString([]byte(auth))) +} + +// NewClientOptionBearer 函数设置请求Bearer认证。 +func NewClientOptionBearer(bearer string) *ClientOption { + return NewClientOptionHeader(HeaderAuthorization, "Bearer "+bearer) +} + +// NewClientOptionUserAgent 函数设置请求UA。 +func NewClientOptionUserAgent(ua string) *ClientOption { + return NewClientOptionHeader(HeaderUserAgent, ua) +} + +// NewClientOptionHost 函数设置请求Host。 +func NewClientOptionHost(host string) func(*http.Request) { + return func(req *http.Request) { req.Host = host } +} + +// NewClientOptionHeader 函数设置请求Header。 +func NewClientOptionHeader(key, val string) *ClientOption { + if val == "" { + return nil + } + return &ClientOption{ + Headers: []string{key, val}, + } +} + +// NewClientTraceWithContext 函数将ClientTrace初始化并绑定到context.Context。 +// +//nolint:funlen +func NewClientTraceWithContext(ctx context.Context, trace *ClientTrace) context.Context { + trace.HTTPStart = time.Now() + ctx = context.WithValue(ctx, ContextKeyClientTrace, trace) + ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ + DNSStart: func(info httptrace.DNSStartInfo) { + trace.DNSStart = time.Now() + trace.DNSHost = info.Host + }, + DNSDone: func(info httptrace.DNSDoneInfo) { + trace.DNSDone = time.Now() + trace.DNSDuration = trace.DNSDone.Sub(trace.DNSStart) + trace.DNSAddrs = info.Addrs + }, + ConnectStart: func(network, addr string) { + trace.Lock() + defer trace.Unlock() + trace.Connect = append(trace.Connect, ClientTraceConnect{ + Start: time.Now(), + Network: network, + Address: addr, + }) + }, + ConnectDone: func(network, addr string, err error) { + trace.Lock() + defer trace.Unlock() + for i := range trace.Connect { + if trace.Connect[i].Network == network && trace.Connect[i].Address == addr { + trace.Connect[i].Done = time.Now() + trace.Connect[i].Duration = trace.Connect[i].Done.Sub(trace.Connect[i].Start) + trace.Connect[i].Error = err + return + } + } + }, + GetConn: func(hostPort string) { + trace.GetConn = time.Now() + trace.GetConnHostPort = hostPort + }, + GotConn: func(httptrace.GotConnInfo) { trace.GotConn = time.Now() }, + GotFirstResponseByte: func() { trace.GotFirstResponseByte = time.Now() }, + TLSHandshakeStart: func() { trace.TLSHandshakeStart = time.Now() }, + TLSHandshakeDone: func(state tls.ConnectionState, err error) { + trace.Lock() + defer trace.Unlock() + trace.TLSHandshakeDone = time.Now() + trace.TLSHandshakeDuration = trace.TLSHandshakeDone.Sub(trace.TLSHandshakeStart) + trace.TLSHandshakeError = err + + if state.PeerCertificates != nil { + cert := state.PeerCertificates[0] + trace.TLSHandshakeIssuer = cert.Issuer.String() + trace.TLSHandshakeSubject = cert.Subject.String() + trace.TLSHandshakeNotBefore = cert.NotBefore + trace.TLSHandshakeNotAfter = cert.NotAfter + h := sha256.New() + h.Write(cert.Raw) + trace.TLSHandshakeDigest = hex.EncodeToString(h.Sum(nil)) + } + }, + WroteHeaderField: func(key string, value []string) { + if trace.WroteHeaders == nil { + trace.WroteHeaders = make(http.Header) + } + trace.WroteHeaders[key] = value + }, + }) + return ctx +} + +// NewClientRetryNetwork 函数创建一个网络重试配置。 +// +// 在err不为空或DefaultClinetRetryStatus指定状态码时,重试请求。 +func NewClientRetryNetwork(max int) ClientRetry { + return ClientRetry{ + Max: max, + Condition: func(attempt int, resp *http.Response, _ error) bool { + retry := resp == nil || DefaultClinetRetryStatus[resp.StatusCode] + if retry { + time.Sleep(time.Second * time.Duration((attempt + 1))) + } + return retry + }, + } +} + +// NewClientRetryDigest 函数创建一个摘要认证配置,在401时重新发起请求。 +func NewClientRetryDigest(username, password string) ClientRetry { + return ClientRetry{ + Max: 1, + Condition: func(_ int, resp *http.Response, _ error) bool { + if resp == nil || resp.StatusCode != StatusUnauthorized { + return false + } + + dig := newclientDigest(resp.Header.Get(HeaderWWWAuthenticate)) + if dig == nil || dig.invalid() { + return false + } + + req := resp.Request + if dig.Qop == httpDigestQopAuthInt && req.Body != nil { + // check GetBody in dotry + dig.Body, _ = req.GetBody() + } + + dig.Nc = "00000001" + dig.Username = username + dig.Password = password + dig.Method = req.Method + dig.URI = req.URL.Path + req.Header.Set(HeaderAuthorization, dig.Encode()) + return true + }, + } +} + +var ( + digestKeys = [...]string{ + "username", "uri", + "realm", "algorithm", "nonce", "qop", + "nc", "cnonce", "response", "opaque", + } + httpDigestQopAuth = "auth" + httpDigestQopAuthInt = "auth-int" +) + +type clientDigest struct { + Hash hash.Hash + Body io.ReadCloser + Password string + Method string + Username string + URI string + + Realm string + Algorithm string + Nonce string + Qop string + + Nc string + Cnonce string + Response string + Opaque string +} + +func newclientDigest(req string) *clientDigest { + if !strings.HasPrefix(req, "Digest ") { + return nil + } + req = req[7:] + + dig := &clientDigest{} + for _, s := range splitDigestString(req) { + k, v, ok := strings.Cut(s, "=") + if !ok { + return nil + } + if len(v) > 2 && v[0] == '"' && v[len(v)-1] == '"' { + v = v[1 : len(v)-1] + } + + switch k { + case "realm": + dig.Realm = v + case "algorithm": + dig.Algorithm = strings.ToUpper(v) + case "nonce": + dig.Nonce = v + case "qop": + dig.Qop = strings.TrimSpace(strings.SplitN(v, ",", 2)[0]) + case "opaque": + dig.Opaque = v + default: + return nil + } + } + + return dig +} + +func splitDigestString(str string) []string { + var pos int + var char bool + var strs []string + for i, b := range str { + switch b { + case ',': + if char { + continue + } + strs = append(strs, strings.TrimSpace(str[pos:i])) + pos = i + 1 + case '"': + char = !char + } + } + strs = append(strs, strings.TrimSpace(str[pos:])) + return strs +} + +func (dig *clientDigest) invalid() bool { + switch dig.Algorithm { + case "MD5", "MD5-SESS", "SHA-256", "SHA-256-SESS": + default: + return true + } + switch dig.Qop { + case "", httpDigestQopAuth, httpDigestQopAuthInt: + default: + return true + } + return false +} + +func (dig *clientDigest) Encode() string { + dig.Cnonce = GetStringRandom(40) + var ha1, ha2 string + switch dig.Algorithm { + case "MD5", "MD5-SESS": + dig.Hash = md5.New() + ha1 = dig.digestHash(fmt.Sprintf("%s:%s:%s", dig.Username, dig.Realm, dig.Password)) + case "SHA-256", "SHA-256-SESS": + dig.Hash = sha256.New() + ha1 = dig.digestHash(fmt.Sprintf("%s:%s:%s", dig.Username, dig.Realm, dig.Password)) + } + if strings.HasSuffix(dig.Algorithm, "-SESS") { + ha1 = dig.digestHash(fmt.Sprintf("%s:%s:%s", ha1, dig.Nonce, dig.Cnonce)) + } + + switch dig.Qop { + case httpDigestQopAuth, "": + ha2 = dig.digestHash(fmt.Sprintf("%s:%s", dig.Method, dig.URI)) + case httpDigestQopAuthInt: + if dig.Body != nil { + dig.Hash.Reset() + io.Copy(dig.Hash, dig.Body) + ha2 = hex.EncodeToString(dig.Hash.Sum(nil)) + dig.Body.Close() + } + ha2 = dig.digestHash(fmt.Sprintf("%s:%s:%s", dig.Method, dig.URI, ha2)) + } + + switch dig.Qop { + case httpDigestQopAuth, httpDigestQopAuthInt: + dig.Response = dig.digestHash(fmt.Sprintf("%s:%s:00000001:%s:%s:%s", ha1, dig.Nonce, dig.Cnonce, dig.Qop, ha2)) + case "": + dig.Response = dig.digestHash(fmt.Sprintf("%s:%s:%s", ha1, dig.Nonce, ha2)) + } + + buf := bytes.NewBufferString("Digest ") + data := *(*[14]string)(unsafe.Pointer(dig)) + for i, s := range data[4:] { + if s != "" { + switch i { + case 3, 5, 6: + fmt.Fprintf(buf, "%s=%s, ", digestKeys[i], s) + default: + fmt.Fprintf(buf, "%s=\"%s\", ", digestKeys[i], s) + } + } + } + buf.Truncate(buf.Len() - 2) + return buf.String() +} + +func (dig *clientDigest) digestHash(s string) string { + dig.Hash.Reset() + io.WriteString(dig.Hash, s) + return hex.EncodeToString(dig.Hash.Sum(nil)) +} diff --git a/config.go b/config.go index 287c01a..322203d 100644 --- a/config.go +++ b/config.go @@ -3,65 +3,71 @@ package eudore import ( "context" "encoding/json" + "errors" "fmt" "os" "reflect" + "regexp" "runtime" "strings" "sync" + "time" ) /* Config defines configuration management and uses configuration read-write and analysis functions. Get/Set read and write data implementation: -Use custom map or struct as data storage -Support Lock concurrency safety -Access attributes based on string path hierarchy + + Use custom struct or map as data storage + Support Lock concurrency safety + Access attributes based on string path hierarchy The default analysis function implementation: -Custom configuration analysis function -Parse multiple json files -Parse the length and short parameters of the command line -Parse Env environment variables -Switch working directory -Generate help information based on the structure + + Custom configuration analysis function + Parse multiple json files + Parse the length and short parameters of the command line + Parse Env environment variables + Switch working directory + Generate help information based on the structure Config 定义配置管理,使用配置读写和解析功能。 Get/Set读写数据实现下列功能: -使用自定义map或struct作为数据存储 -支持Lock并发安全 -基于字符串路径层次访问属性 + + 使用自定义struct或map作为数据存储 + 支持Lock并发安全 + 基于字符串路径层次访问属性 默认解析函数实现下列功能: -自定义配置解析函数 -解析多json文件 -解析命令行长短参数 -解析Env环境变量 -切换工作目录 -根据结构体生成帮助信息 + + 自定义配置解析函数 + 解析多json文件 + 解析命令行长短参数 + 解析Env环境变量 + 切换工作目录 + 根据结构体生成帮助信息 */ type Config interface { - Get(string) interface{} - Set(string, interface{}) error - ParseOption([]ConfigParseFunc) []ConfigParseFunc - Parse() error + Get(string) any + Set(string, any) error + ParseOption(...ConfigParseFunc) + Parse(context.Context) error } // ConfigParseFunc 定义配置解析函数。 // -// Config 默认解析函数为eudore.ConfigAllParseFunc +// Config 默认解析函数为eudore.ConfigAllParseFunc。 type ConfigParseFunc func(context.Context, Config) error // configStd 使用结构体或map保存配置,通过属性或反射来读写属性。 type configStd struct { - Context context.Context - Data interface{} `alias:"data" description:"all data"` - Map map[string]interface{} - Funcs []ConfigParseFunc `alias:"funcs" description:"config parse funcs"` - Err error `alias:"err" description:"config pasre error"` - rwLocker `alias:"-"` + Data any `alias:"data" json:"data" xml:"data" yaml:"data" description:"any data"` + Map map[string]any `alias:"map" json:"map" xml:"map" yaml:"map" description:"map data"` + Funcs []ConfigParseFunc `alias:"funcs" json:"funcs" xml:"funcs" yaml:"funcs" description:"all parse funcs"` + Err error `alias:"err" json:"err" xml:"err" yaml:"err" description:"parsing error"` + Lock rwLocker `alias:"lock" json:"-" xml:"-" yaml:"-"` } type rwLocker interface { @@ -70,10 +76,16 @@ type rwLocker interface { RUnlock() } -// NewConfigStd function creates a configStd. -// If the incoming parameter is empty, use map[string]interface{} as metadata. +type MetadataConfig struct { + Health bool `alias:"health" json:"health" xml:"health" yaml:"health"` + Name string `alias:"name" json:"name" xml:"name" yaml:"name"` + Error error `alias:"error,omitempty" json:"error,omitempty" xml:"error,omitempty" yaml:"error,omitempty"` +} + +// NewConfig function creates a configStd. +// If the incoming parameter is empty, use map[string]any as metadata. // -// If the metadata type is map[string]interface{}, use map to read and write key values, +// If the metadata type is map[string]any, use map to read and write key values, // otherwise use eudore.Set and eudore.Get methods to read and write metadata. // // If the incoming configuration object implements the same read-write lock method as sync.RLock, @@ -81,129 +93,289 @@ type rwLocker interface { // // configStd has implemented the json.Marshaler and json.Unmarshaler interfaces. // -// NewConfigStd 函数创建一个configStd,如果传入参数为空,使用map[string]interface{}作为元数据。 +// NewConfig 函数创建一个configStd,如果传入参数为空,使用map[string]any作为元数据。 // -// 如果元数据类型为map[string]interface{}使用map读写键值,否则其他类型使用eudore.Set和eudore.Get方法去读元写数据。 +// 如果元数据类型为map[string]any使用map读写键值,否则其他类型使用eudore.Set和eudore.Get方法去读元写数据。 // // 如果传入的配置对象实现sync.RLock一样的读写锁方法,则使用配置的读写锁,否则会创建一个sync.RWMutex锁。 // // configStd已实现json.Marshaler和json.Unmarshaler接口. -func NewConfigStd(data interface{}) Config { +func NewConfig(data any) Config { if data == nil { - data = make(map[string]interface{}) + data = make(map[string]any) } mu, ok := data.(rwLocker) if !ok { - mu = new(sync.RWMutex) + mu = &sync.RWMutex{} } - m, _ := data.(map[string]interface{}) + m, _ := data.(map[string]any) return &configStd{ - Context: context.Background(), - Data: data, - Map: m, - Funcs: DefaultConfigAllParseFunc, - rwLocker: mu, + Data: data, + Map: m, + Funcs: DefaultConfigAllParseFunc, + Lock: mu, } } -// Mount 方法获取环境上下文。 -func (cnf *configStd) Mount(ctx context.Context) { - cnf.Context = ctx +func (conf *configStd) Metadata() any { + return MetadataConfig{ + Health: conf.Err == nil, + Name: "eudore.configStd", + Error: conf.Err, + } } // The Get method realizes to read the data attributes, and uses the RLock method to lock the data, // if key is empty string return metadata. // // Get 方法实现读取数据属性,并使用RLock方法锁定数据,如果key为空字符串返回元数据。 -func (cnf *configStd) Get(key string) interface{} { +func (conf *configStd) Get(key string) any { if len(key) == 0 { - return cnf.Data + return conf.Data } - cnf.RLock() - defer cnf.RUnlock() - if cnf.Map != nil { - return cnf.Map[key] + conf.Lock.RLock() + defer conf.Lock.RUnlock() + if conf.Map != nil { + return conf.Map[key] } - val, _ := GetWithTags(cnf.Data, key, DefaultConfigGetSetTags, false) - return val + return GetAnyByPath(conf.Data, key) } // The Set method implements setting data, and uses the Lock method to lock the data, // If key is empty string set metadata. // // Set 方法实现设置数据,并使用Lock方法锁定数据,如果key为空字符串设置元数据。 -func (cnf *configStd) Set(key string, val interface{}) error { - cnf.Lock() - defer cnf.Unlock() +func (conf *configStd) Set(key string, val any) error { + conf.Lock.Lock() + defer conf.Lock.Unlock() if len(key) == 0 { - cnf.Data = val - cnf.Map, _ = val.(map[string]interface{}) + conf.Data = val + conf.Map, _ = val.(map[string]any) return nil } - if cnf.Map != nil { - cnf.Map[key] = val + if conf.Map != nil { + conf.Map[key] = val return nil } - return SetWithTags(cnf.Data, key, val, DefaultConfigGetSetTags, false) + return SetAnyByPath(conf.Data, key, val) } // ParseOption executes a configuration parsing function option. // // ParseOption 执行一个配置解析函数选项。 -func (cnf *configStd) ParseOption(fn []ConfigParseFunc) []ConfigParseFunc { - cnf.Funcs, fn = fn, cnf.Funcs - cnf.Err = nil - return fn +func (conf *configStd) ParseOption(fn ...ConfigParseFunc) { + if fn == nil { + conf.Funcs = nil + conf.Err = nil + } else { + conf.Funcs = append(conf.Funcs, fn...) + } } // The Parse method executes all configuration parsing functions. // If the parsing function returns error, it stops parsing and returns error. // // Parse 方法执行全部配置解析函数,如果其中解析函数返回error,则停止解析并返回error。 -func (cnf *configStd) Parse() error { - if cnf.Err != nil { - return cnf.Err +func (conf *configStd) Parse(ctx context.Context) error { + if conf.Err != nil { + return conf.Err } - for _, fn := range cnf.Funcs { - cnf.Err = fn(cnf.Context, cnf) - if cnf.Err != nil { - NewLoggerWithContext(cnf.Context).Errorf("configStd parse func %v error: %v", runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), cnf.Err) - return cnf.Err + log := NewLoggerWithContext(ctx) + for _, fn := range conf.Funcs { + conf.Err = fn(ctx, conf) + if conf.Err != nil { + if !errors.Is(conf.Err, context.Canceled) { + name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() + log.Errorf("config parse func %v error: %v", name, conf.Err) + } + return conf.Err } } - NewLoggerWithContext(cnf.Context).Info("configStd parse done") + log.Info("config parse done") return nil } -// MarshalJSON implements the json.Marshaler interface, which enables json serialization to directly manipulate the saved data. +// MarshalJSON implements the json.Marshaler interface, +// which enables json serialization to directly manipulate the saved data. // // MarshalJSON 实现json.Marshaler接口,使json序列化直接操作保存的数据。 -func (cnf *configStd) MarshalJSON() ([]byte, error) { - cnf.RLock() - defer cnf.RUnlock() - return json.Marshal(cnf.Data) +func (conf *configStd) MarshalJSON() ([]byte, error) { + conf.Lock.RLock() + defer conf.Lock.RUnlock() + return json.Marshal(conf.Data) } -// UnmarshalJSON implements the json.Unmarshaler interface, which enables json deserialization to directly manipulate the saved data. +// UnmarshalJSON implements the json.Unmarshaler interface, +// which enables json deserialization to directly manipulate the saved data. // // UnmarshalJSON 实现json.Unmarshaler接口,使json反序列化直接操作保存的数据。 -func (cnf *configStd) UnmarshalJSON(data []byte) error { - cnf.Lock() - defer cnf.Unlock() - return json.Unmarshal(data, &cnf.Data) +func (conf *configStd) UnmarshalJSON(data []byte) error { + conf.Lock.Lock() + defer conf.Lock.Unlock() + return json.Unmarshal(data, &conf.Data) +} + +/* +The NewConfigParseEnvFile function creates an Env file configuration parsing method. + +If a line of the Env file is in env format from the beginning, +it will be loaded to the os as Env. + +If the first character of the Env value is "'" as a multi-line value, +until the end of a line also has "'"; +Newline characters "\r\n" "\n" in multi-line values are replaced with "\n" and TrimSpace is performed. + +If the Env value is an empty string, the Env will be deleted from os. + +NewConfigParseEnvFile 函数创建Env文件配置解析方法。 + +如果Env文件一行从开始为env格式,则作为Env加载到os。 + +如果Env值为第一个字符为"'"作为多行值,直到一行结尾同样具有"'"; +多行值中的换行符"\r\n" "\n"被替换为"\n",并执行TrimSpace。 + +如果Env值为空字符串会从os删除这个Env。 + +example: + + EUDORE_NAME=eudore + EUDORE_DEBUG= + EUDORE_KEY=' + -----BEGIN RSA PRIVATE KEY----- + -----END RSA PRIVATE KEY----- + ' +*/ +func NewConfigParseEnvFile(files ...string) ConfigParseFunc { + if files == nil { + files = strings.Split(DefaultConfigEnvFiles, ";") + } + reg := regexp.MustCompile(`[a-zA-Z]\w*`) + return func(ctx context.Context, c Config) error { + for _, file := range files { + data, err := os.ReadFile(file) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + + log := NewLoggerWithContext(ctx) + log.Info("confif load env file", file) + lines := strings.Split(strings.ReplaceAll(string(data), "\r\n", "\n"), "\n") + keys := make([]string, 0, len(lines)) + char := "'" + for i := range lines { + key, val, ok := strings.Cut(lines[i], "=") + if ok && reg.MatchString(key) { + if strings.HasPrefix(val, char) { + if strings.HasSuffix(val, char) { + val = strings.TrimSpace(strings.TrimSuffix(val[1:], char)) + } else if i+1 < len(lines) { + lines[i+1] = lines[i] + "\n" + lines[i+1] + continue + } + } + + keys = append(keys, key) + log.Infof("set file environment: %s=%s", key, val) + if val != "" { + os.Setenv(key, val) + } else { + os.Unsetenv(key) + } + } + } + os.Setenv("EUDORE_CONFIG_LOAD_ENVS", strings.Join(keys, ",")) + } + return nil + } +} + +/* +The NewConfigParseDefault function creates a default variable parsing function +that gets the value from ENV to set the default variable. + +NewConfigParseDefault 函数创建一个默认变量解析函数,从ENV获取值设置默认变量。 + +env to keys: + + EUDORE_CONTEXT_MAX_HANDLER => DefaultContextMaxHandler + EUDORE_CONTEXT_MAX_APPLICATION_FORM_SIZE => DefaultContextMaxApplicationFormSize + EUDORE_CONTEXT_MAX_MULTIPART_FORM_MEMORY => DefaultContextMaxMultipartFormMemory + EUDORE_CONTEXT_FORM_MAX_MEMORY => DefaultContextFormMaxMemory + EUDORE_HANDLER_EMBED_CACHE_CONTROL => DefaultHandlerEmbedCacheControl + EUDORE_HANDLER_EMBED_TIME => DefaultHandlerEmbedTime + EUDORE_LOGGER_DEPTH_MAX_STACK => DefaultLoggerDepthMaxStack + EUDORE_LOGGER_ENABLE_HOOK_FATAL => DefaultLoggerEnableHookFatal + EUDORE_LOGGER_ENABLE_HOOK_META => DefaultLoggerEnableHookMeta + EUDORE_LOGGER_ENABLE_STD_COLOR => DefaultLoggerEnableStdColor + EUDORE_LOGGER_ENTRY_BUFFER_LENGTH => DefaultLoggerEntryBufferLength + EUDORE_LOGGER_ENTRY_FIELDS_LENGTH => DefaultLoggerEntryFieldsLength + EUDORE_LOGGER_FORMATTER => DefaultLoggerFormatter + EUDORE_LOGGER_FORMATTER_FORMAT_TIME => DefaultLoggerFormatterFormatTime + EUDORE_LOGGER_FORMATTER_KEY_LEVEL => DefaultLoggerFormatterKeyLevel + EUDORE_LOGGER_FORMATTER_KEY_MESSAGE => DefaultLoggerFormatterKeyMessage + EUDORE_LOGGER_FORMATTER_KEY_TIME => DefaultLoggerFormatterKeyTime + EUDORE_LOGGER_WRITER_STDOUT_WINDOWS_COLOR => DefaultLoggerWriterStdoutWindowsColor + EUDORE_ROUTER_LOGGER_KIND => DefaultRouterLoggerKind + EUDORE_SERVER_READ_TIMEOUT => DefaultServerReadTimeout + EUDORE_SERVER_READ_HEADER_TIMEOUT => DefaultServerReadHeaderTimeout + EUDORE_SERVER_WRITE_TIMEOUT => DefaultServerWriteTimeout + EUDORE_SERVER_IDLE_TIMEOUT => DefaultServerIdleTimeout + EUDORE_SERVER_SHUTDOWN_WAIT => DefaultServerShutdownWait + EUDORE_DAEMON_PIDFILE => DefaultDaemonPidfile + EUDORE_GODOC_SERVER => DefaultGodocServer + EUDORE_TRACE_SERVER => DefaultTraceServer +*/ +func NewConfigParseDefault() ConfigParseFunc { + return func(ctx context.Context, c Config) error { + parseEnvDefault(&DefaultContextMaxHandler, "CONTEXT_MAX_HANDLER") + parseEnvDefault(&DefaultContextMaxApplicationFormSize, "CONTEXT_MAX_APPLICATION_FORM_SIZE") + parseEnvDefault(&DefaultContextMaxMultipartFormMemory, "CONTEXT_MAX_MULTIPART_FORM_MEMORY") + parseEnvDefault(&DefaultHandlerEmbedCacheControl, "HANDLER_EMBED_CACHE_CONTROL") + parseEnvDefault(&DefaultHandlerEmbedTime, "HANDLER_EMBED_TIME") + parseEnvDefault(&DefaultLoggerDepthMaxStack, "LOGGER_DEPTH_MAX_STACK") + parseEnvDefault(&DefaultLoggerEnableHookFatal, "LOGGER_ENABLE_HOOK_FATAL") + parseEnvDefault(&DefaultLoggerEnableHookMeta, "LOGGER_ENABLE_HOOK_META") + parseEnvDefault(&DefaultLoggerEnableStdColor, "LOGGER_ENABLE_STD_COLOR") + parseEnvDefault(&DefaultLoggerEntryBufferLength, "LOGGER_ENTRY_BUFFER_LENGTH") + parseEnvDefault(&DefaultLoggerEntryFieldsLength, "LOGGER_ENTRY_FIELDS_LENGTH") + parseEnvDefault(&DefaultLoggerFormatter, "LOGGER_FORMATTER") + parseEnvDefault(&DefaultLoggerFormatterFormatTime, "LOGGER_FORMATTER_FORMAT_TIME") + parseEnvDefault(&DefaultLoggerFormatterKeyLevel, "LOGGER_FORMATTER_KEY_LEVEL") + parseEnvDefault(&DefaultLoggerFormatterKeyMessage, "LOGGER_FORMATTER_KEY_MESSAGE") + parseEnvDefault(&DefaultLoggerFormatterKeyTime, "LOGGER_FORMATTER_KEY_TIME") + parseEnvDefault(&DefaultLoggerWriterStdoutWindowsColor, "LOGGER_WRITER_STDOUT_WINDOWS_COLOR") + parseEnvDefault(&DefaultRouterLoggerKind, "ROUTER_LOGGER_KIND") + parseEnvDefault(&DefaultServerReadTimeout, "SERVER_READ_TIMEOUT") + parseEnvDefault(&DefaultServerReadHeaderTimeout, "SERVER_READ_HEADER_TIMEOUT") + parseEnvDefault(&DefaultServerWriteTimeout, "SERVER_WRITE_TIMEOUT") + parseEnvDefault(&DefaultServerIdleTimeout, "SERVER_IDLE_TIMEOUT") + parseEnvDefault(&DefaultServerShutdownWait, "SERVER_SHUTDOWN_WAIT") + parseEnvDefault(&DefaultDaemonPidfile, "DAEMON_PIDFILE") + parseEnvDefault(&DefaultGodocServer, "GODOC_SERVER") + parseEnvDefault(&DefaultTraceServer, "TRACE_SERVER") + return nil + } +} + +func parseEnvDefault[T string | bool | TypeNumber | time.Time | time.Duration](val *T, key string) { + *val = GetAnyByString(os.Getenv("EUDORE_"+key), *val) } // NewConfigParseJSON method parses the json file configuration, usually the key is "config". // -// The configuration item value is string(';' divided into multiple paths) or []string, if the loaded file does not exist, the file will be ignored. +// The configuration item value is string(';' divided into multiple paths) or []string, +// if the loaded file does not exist, the file will be ignored. // // NewConfigParseJSON 方法解析json文件配置,通常使用key为"config"。 // // 配置项值为string(';'分割为多路径)或[]string,如果加载文件不存在将忽略文件。 func NewConfigParseJSON(key string) ConfigParseFunc { - return func(ctx context.Context, cnf Config) error { + return func(ctx context.Context, conf Config) error { var paths []string - switch val := cnf.Get(key).(type) { + switch val := conf.Get(key).(type) { case string: paths = strings.Split(val, ";") case []string: @@ -221,9 +393,9 @@ func NewConfigParseJSON(key string) ConfigParseFunc { continue } defer file.Close() - err = json.NewDecoder(file).Decode(cnf) + err = json.NewDecoder(file).Decode(conf) if err != nil { - err = fmt.Errorf("config parse json file '%s' error: %v", path, err) + err = fmt.Errorf("config parse json file '%s' error: %w", path, err) log.Info(err) return err } @@ -236,18 +408,21 @@ func NewConfigParseJSON(key string) ConfigParseFunc { // NewConfigParseArgs function uses the eudore.Set method to set the command line parameter data, // and the command line parameter uses the format of'--{key}.{sub}={value}'. // -// Shortsmap is mapped as a short parameter. If the structure has a'flag' tag, it will be used as the abbreviation of the path. -// The tag length must be less than 5, the command line format is'-{short}={value}, and the short parameter will automatically be long parameter. +// Shortsmap is mapped as a short parameter. If the structure has a'flag' tag, +// it will be used as the abbreviation of the path. +// The tag length must be less than 5, the command line format is'-{short}={value}, +// and the short parameter will automatically be long parameter. // // NewConfigParseArgs 函数使用eudore.Set方法设置命令行参数数据,命令行参数使用'--{key}.{sub}={value}'格式。 // -// shortsmap作为短参数映射,如果结构体存在'flag' tag将作为该路径的缩写,tag长度需要小于5,命令行格式为'-{short}={value},短参数将会自动为长参数。 +// shortsmap作为短参数映射,如果结构体存在'flag' tag将作为该路径的缩写, +// tag长度需要小于5,命令行格式为'-{short}={value},短参数将会自动为长参数。 func NewConfigParseArgs(shortsmap map[string][]string) ConfigParseFunc { - return func(ctx context.Context, cnf Config) error { + return func(ctx context.Context, conf Config) error { // 使用结构体tag初始化shorts shorts := make(map[string][]string) flag := &eachTags{tag: "flag", Repeat: make(map[uintptr]string)} - flag.Each("", reflect.ValueOf(cnf.Get(""))) + flag.Each("", reflect.ValueOf(conf.Get(""))) for i, tag := range flag.Tags { shorts[flag.Vals[i]] = append(shorts[flag.Vals[i]], tag[1:]) } @@ -255,60 +430,63 @@ func NewConfigParseArgs(shortsmap map[string][]string) ConfigParseFunc { shorts[k] = append(shorts[k], v...) } + args := []string{} log := NewLoggerWithContext(ctx) for _, str := range os.Args[1:] { - key, val := split2byte(str, '=') - if strings.HasPrefix(key, "--") { // 长参数 - if val == "" && reflect.ValueOf(cnf.Get(key[2:])).Kind() == reflect.Bool { - val = "true" - } - log.Info("config set arg: " + str) - cnf.Set(key[2:], val) - } else if len(key) > 1 && key[0] == '-' && key[1] != '-' { // 短参数 + key, val, _ := strings.Cut(str, "=") + switch { + case strings.HasPrefix(key, "--"): // 长参数 + log.Info("set os argument: " + str) + conf.Set(key[2:], val) + case len(key) > 1 && key[0] == '-' && key[1] != '-': // 短参数 for _, lkey := range shorts[key[1:]] { - val := val - if val == "" && reflect.ValueOf(cnf.Get(lkey)).Kind() == reflect.Bool { - val = "true" - } - log.Infof("config set short arg '%s': --%s=%s", key[1:], lkey, val) - cnf.Set(lkey, val) + log.Infof("set os short argument '%s': --%s=%s", key[1:], lkey, val) + conf.Set(lkey, val) } + default: + args = append(args, str) } } + conf.Set("args", args) return nil } } -// NewConfigParseEnvs function uses the eudore.Set method to set the environment variable data, usually the environment variable prefix uses'ENV_'. +// NewConfigParseEnvs function uses the eudore.Set method to set the environment variable data, +// usually the environment variable prefix uses'ENV_'. // -// Environment variables will be converted to lowercase paths, and the underscore of'_' is equivalent to the function of'.'. +// Environment variables will be converted to lowercase paths, +// and the underscore of'_' is equivalent to the function of'.'. // -// NewConfigParseEnvs 函数使用eudore.Set方法设置环境变量数据,通常环境变量前缀使用'ENV_'。 +// NewConfigParseEnvs 函数使用eudore.Set方法设置环境变量数据,环境变量默认前缀使用'ENV_'。 // -// 环境变量将转换成小写路径,'_'下划线相当于'.'的作用 +// 环境变量将移除前缀转换成小写路径,'_'下划线相当于'.'的作用 // // exmapel: 'ENV_EUDORE_NAME=eudore' => 'eudore.name=eudore'。 -func NewConfigParseEnvs(key string) ConfigParseFunc { - return func(ctx context.Context, cnf Config) error { +func NewConfigParseEnvs(prefix string) ConfigParseFunc { + l := len(prefix) + return func(ctx context.Context, conf Config) error { log := NewLoggerWithContext(ctx) - for _, value := range os.Environ() { - if strings.HasPrefix(value, "ENV_") { - log.Info("config set env: " + value) - k, v := split2byte(value, '=') - k = strings.ToLower(strings.Replace(k, "_", ".", -1))[4:] - cnf.Set(k, v) + for _, env := range os.Environ() { + if strings.HasPrefix(env, prefix) { + log.Infof("set os environment: %s", env) + k, v, _ := strings.Cut(env, "=") + if k != "" { + conf.Set(strings.ToLower(strings.ReplaceAll(k[l:], "_", ".")), v) + } } } return nil } } -// NewConfigParseWorkdir function initializes the workspace, usually using the key as string("workdir") to obtain the workspace directory and switch. +// NewConfigParseWorkdir function initializes the workspace, +// usually using the key as string("workdir") to obtain the workspace directory and switch. // // NewConfigParseWorkdir 函数初始化工作空间,通常使用key为string("workdir"),获取工作空间目录并切换。 func NewConfigParseWorkdir(key string) ConfigParseFunc { - return func(ctx context.Context, cnf Config) error { - dir, ok := cnf.Get(key).(string) + return func(ctx context.Context, conf Config) error { + dir, ok := conf.Get(key).(string) if ok && dir != "" { NewLoggerWithContext(ctx).Info("changes working directory to: " + dir) return os.Chdir(dir) @@ -317,11 +495,14 @@ func NewConfigParseWorkdir(key string) ConfigParseFunc { } } -// NewConfigParseHelp function if uses the structure configuration to output the'flag' and'description' tags to produce the default parameter description. +// NewConfigParseHelp function if uses the structure configuration to output the'flag' +// and'description' tags to produce the default parameter description. // -// By default, only the parameter description is output. For other descriptions, please wrap the NewConfigParseHelp method. +// By default, only the parameter description is output. For other descriptions, +// please wrap the NewConfigParseHelp method. // -// Note that the properties of the configuration structure need to be non-empty, otherwise it will not enter the traversal. +// Note that the properties of the configuration structure need to be non-empty, +// otherwise it will not enter the traversal. // // NewConfigParseHelp 函数如果使用结构体配置输出'flag'和'description' tag生产默认参数描述。 // @@ -329,22 +510,21 @@ func NewConfigParseWorkdir(key string) ConfigParseFunc { // // 注意配置结构体的属性需要是非空,否则不会进入遍历。 func NewConfigParseHelp(key string) ConfigParseFunc { - return func(ctx context.Context, cnf Config) error { - help, ok := cnf.Get(key).(bool) - if !ok || !help { + return func(ctx context.Context, conf Config) error { + if !GetAny[bool](conf.Get(key)) { return nil } - conf := reflect.ValueOf(cnf.Get("")) + data := reflect.ValueOf(conf.Get("")) flag := &eachTags{tag: "flag", Repeat: make(map[uintptr]string)} - flag.Each("", conf) + flag.Each("", data) flagmap := make(map[string]string) for i, tag := range flag.Tags { flagmap[tag[1:]] = flag.Vals[i] } desc := &eachTags{tag: "description", Repeat: make(map[uintptr]string)} - desc.Each("", conf) + desc.Each("", data) var length int for i, tag := range desc.Tags { desc.Tags[i] = tag[1:] @@ -356,11 +536,11 @@ func NewConfigParseHelp(key string) ConfigParseFunc { for i, tag := range desc.Tags { f, ok := flagmap[tag] if ok && !strings.Contains(tag, "{") && len(f) < 5 { - fmt.Printf(" -%s,", f) + fmt.Printf(" -%s,", f) //nolint:forbidigo } - fmt.Printf("\t --%s=%s\t%s\n", tag, strings.Repeat(" ", length-len(tag)), desc.Vals[i]) + fmt.Printf("\t --%s=%s\t%s\r\n", tag, strings.Repeat(" ", length-len(tag)), desc.Vals[i]) //nolint:forbidigo } - return nil + return context.Canceled } } @@ -372,32 +552,32 @@ type eachTags struct { LastTag string } -func (each *eachTags) Each(prefix string, iValue reflect.Value) { - switch iValue.Kind() { +func (each *eachTags) Each(prefix string, v reflect.Value) { + switch v.Kind() { case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer: - if !iValue.IsNil() { - _, ok := each.Repeat[iValue.Pointer()] + if !v.IsNil() { + _, ok := each.Repeat[v.Pointer()] if ok { return } - each.Repeat[iValue.Pointer()] = prefix + each.Repeat[v.Pointer()] = prefix } } - switch iValue.Kind() { + switch v.Kind() { case reflect.Ptr, reflect.Interface: - if !iValue.IsNil() { - each.Each(prefix, iValue.Elem()) + if !v.IsNil() { + each.Each(prefix, v.Elem()) } case reflect.Map: if each.LastTag != "" { - each.Tags = append(each.Tags, fmt.Sprintf("%s.{%s}", prefix, iValue.Type().Key().Name())) + each.Tags = append(each.Tags, fmt.Sprintf("%s.{%s}", prefix, v.Type().Key().Name())) each.Vals = append(each.Vals, each.LastTag) } case reflect.Slice, reflect.Array: length := "n" - if iValue.Kind() == reflect.Array { - length = fmt.Sprint(iValue.Type().Len() - 1) + if v.Kind() == reflect.Array { + length = fmt.Sprint(v.Type().Len() - 1) } last := each.LastTag if last != "" { @@ -405,16 +585,16 @@ func (each *eachTags) Each(prefix string, iValue reflect.Value) { each.Vals = append(each.Vals, last) } each.LastTag = last - each.Each(fmt.Sprintf("%s.{0-%s}", prefix, length), reflect.New(iValue.Type().Elem())) + each.Each(fmt.Sprintf("%s.{0-%s}", prefix, length), reflect.New(v.Type().Elem())) case reflect.Struct: - each.EachStruct(prefix, iValue) + each.EachStruct(prefix, v) } } -func (each *eachTags) EachStruct(prefix string, iValue reflect.Value) { - iType := iValue.Type() +func (each *eachTags) EachStruct(prefix string, v reflect.Value) { + iType := v.Type() for i := 0; i < iType.NumField(); i++ { - if iValue.Field(i).CanSet() { + if v.Field(i).CanSet() { val := iType.Field(i).Tag.Get(each.tag) name := iType.Field(i).Tag.Get("alias") if name == "" { @@ -425,7 +605,7 @@ func (each *eachTags) EachStruct(prefix string, iValue reflect.Value) { each.Vals = append(each.Vals, val) } each.LastTag = val - each.Each(prefix+"."+name, iValue.Field(i)) + each.Each(prefix+"."+name, v.Field(i)) } } } @@ -434,7 +614,8 @@ func (each *eachTags) getValueKind(iType reflect.Type) string { switch iType.Kind() { case reflect.Bool: return "bool" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return "int" case reflect.Float32, reflect.Float64: return "float" diff --git a/const.go b/const.go index 223643d..e3bd6bd 100644 --- a/const.go +++ b/const.go @@ -3,16 +3,61 @@ package eudore // const定义全部全局变量和常量 import ( + "encoding" + "encoding/json" "errors" "fmt" - "html/template" + "net" "reflect" "time" ) +var ( + // 定义各种类型的反射类型。 + typeAny = reflect.TypeOf((*any)(nil)).Elem() + typeError = reflect.TypeOf((*error)(nil)).Elem() + typeContext = reflect.TypeOf((*Context)(nil)).Elem() + typeHandlerFunc = reflect.TypeOf((*HandlerFunc)(nil)).Elem() + typeTimeDuration = reflect.TypeOf((*time.Duration)(nil)).Elem() + typeTimeTime = reflect.TypeOf((*time.Time)(nil)).Elem() + typeFmtStringer = reflect.TypeOf((*fmt.Stringer)(nil)).Elem() + typeJSONMarshaler = reflect.TypeOf((*json.Marshaler)(nil)).Elem() + typeTextMarshaler = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() + // 检测各类接口。 + _ Client = (*clientStd)(nil) + _ Config = (*configStd)(nil) + _ Context = (*contextBase)(nil) + _ Controller = (*ControllerAutoRoute)(nil) + _ Controller = (*controllerError)(nil) + _ FuncCreator = (*funcCreatorBase)(nil) + _ FuncCreator = (*funcCreatorExpr)(nil) + _ HandlerExtender = (*handlerExtenderBase)(nil) + _ HandlerExtender = (*handlerExtenderTree)(nil) + _ HandlerExtender = (*handlerExtenderWarp)(nil) + _ Logger = (*loggerStd)(nil) + _ LoggerHandler = (*loggerFormatterJSON)(nil) + _ LoggerHandler = (*loggerFormatterText)(nil) + _ LoggerHandler = (*loggerHandlerInit)(nil) + _ LoggerHandler = (*loggerHookFilter)(nil) + _ LoggerHandler = (*loggerHookMeta)(nil) + _ LoggerHandler = (*loggerWriterFile)(nil) + _ LoggerHandler = (*loggerWriterRotate)(nil) + _ LoggerHandler = (*loggerWriterStdoutColor)(nil) + _ LoggerHandler = (*loggerWriterStdout)(nil) + _ ResponseWriter = (*responseWriterHTTP)(nil) + _ Router = (*RouterStd)(nil) + _ RouterCore = (*routerCoreHost)(nil) + _ RouterCore = (*routerCoreLock)(nil) + _ RouterCore = (*routerCoreStd)(nil) + _ Server = (*serverFcgi)(nil) + _ Server = (*serverStd)(nil) +) + var ( // ContextKeyApp 定义获取app的Key。 ContextKeyApp = NewContextKey("app") + // ContextKeyAppKeys 定义获取app全部可获取数据keys的Key。 + ContextKeyAppKeys = NewContextKey("app-keys") // ContextKeyLogger 定义获取logger的Key。 ContextKeyLogger = NewContextKey("logger") // ContextKeyConfig 定义获取config的Key。 @@ -33,201 +78,275 @@ var ( ContextKeyError = NewContextKey("error") // ContextKeyBind 定义获取bind的Key。 ContextKeyBind = NewContextKey("bind") - // ContextKeyValidate 定义获取validate的Key。 - ContextKeyValidate = NewContextKey("validate") - // ContextKeyFilte 定义获取filte的Key。 - ContextKeyFilte = NewContextKey("filte") + // ContextKeyValidater 定义获取validate的Key。 + ContextKeyValidater = NewContextKey("validater") + // ContextKeyFilter 定义获取filte的Key。 + ContextKeyFilter = NewContextKey("filter") + // ContextKeyFilterRules 定义获取filte-data的Key。 + ContextKeyFilterRules = NewContextKey("filter-rules") // ContextKeyRender 定义获取render的Key。 ContextKeyRender = NewContextKey("render") + // ContextKeyHandlerExtender 定义获取handler-extender的Key。 + ContextKeyHandlerExtender = NewContextKey("handler-extender") // ContextKeyFuncCreator 定义获取func-creator的Key。 ContextKeyFuncCreator = NewContextKey("func-creator") // ContextKeyTemplate 定义获取templdate的Key。 ContextKeyTemplate = NewContextKey("templdate") // ContextKeyTrace 定义获取trace的Key。 ContextKeyTrace = NewContextKey("trace") - - // DefaultBindFormTags 定义bind form使用tags。 - DefaultBindFormTags = []string{"form", "alias"} - // DefaultBindHeaderTags 定义bind header使用tags。 - DefaultBindHeaderTags = []string{"header", "alias"} - // DefaultBindURLTags 定义bind url使用tags。 - DefaultBindURLTags = []string{"url", "alias"} - // DefaultClientBodyContextType 定义NewClientBody默认使用的内容类型。 - DefaultClientBodyContextType = MimeApplicationJSONCharsetUtf8 + // ContextKeyDaemonCommand 定义获取daemon-command的Key。 + ContextKeyDaemonCommand = NewContextKey("daemon-command") + // ContextKeyDaemonSignal 定义获取daemon-signal的Key。 + ContextKeyDaemonSignal = NewContextKey("daemon-signal") + // ContextKeyDatabaseRuntime 定义获取database-runtime的Key。 + ContextKeyDatabaseRuntime = NewContextKey("database-runtime") + // DefaultClientDialKeepAlive 定义默认DialContext超时时间。 + DefaultClientDialKeepAlive = 30 * time.Second + // DefaultClientDialTimeout 定义默认DialContext超时时间。 + DefaultClientDialTimeout = 30 * time.Second // DefaultClientHost 定义clientStd默认使用的Host。 DefaultClientHost = "localhost:80" - // DefaultClientParseErrStar 定义NewClientParseErr解析err的状态码范围 - DefaultClientParseErrStar = 500 - // DefaultClientParseErrEnd 定义NewClientParseErr解析err的状态码范围 - DefaultClientParseErrEnd = 500 // DefaultClientInternalHost 定义clientStd使用内部连接的Host。 DefaultClientInternalHost = "127.0.0.10:80" - // DefaultConfigAllParseFunc 定义ConfigMap和ConfigEudore默认使用的解析函数。 - DefaultConfigAllParseFunc = []ConfigParseFunc{NewConfigParseJSON("config"), NewConfigParseArgs(nil), - NewConfigParseEnvs("ENV_"), NewConfigParseWorkdir("workdir"), NewConfigParseHelp("help")} - // DefaultConfigGetSetTags 定义ConfigStd默认使用GetSet的tag。 - DefaultConfigGetSetTags = []string{"alias"} + // DefaultClientParseErrStar 定义NewClientParseErr解析err的状态码范围。 + DefaultClientParseErrStar = 500 + // DefaultClientParseErrEnd 定义NewClientParseErr解析err的状态码范围。 + DefaultClientParseErrEnd = 500 + // DefaultClientTimeout 定义客户端默认超时时间。 + DefaultClientTimeout = 30 * time.Second + // DefaultClinetHopHeaders 定义Hop to Hop Header。 + DefaultClinetHopHeaders = [...]string{ + HeaderConnection, + HeaderUpgrade, + HeaderKeepAlive, + HeaderProxyConnection, + HeaderProxyAuthenticate, + HeaderProxyAuthorization, + HeaderTE, + HeaderTrailer, + HeaderTransferEncoding, + } + // DefaultClinetLoggerLevel 定义Client默认最小输出日志级别。 + DefaultClinetLoggerLevel = LoggerError + // DefaultClinetRetryStatus 定义NewClientRetryNetwork重试状态码。 + DefaultClinetRetryStatus = map[int]bool{ + StatusTooManyRequests: true, + StauusClientClosedRequest: true, + StatusBadGateway: true, + StatusServiceUnavailable: true, + StatusGatewayTimeout: true, + } + // DefaultConfigAllParseFunc 定义Config默认使用的解析函数。 + DefaultConfigAllParseFunc = []ConfigParseFunc{ + NewConfigParseEnvFile(), + NewConfigParseDefault(), + NewConfigParseJSON("config"), + NewConfigParseArgs(nil), + NewConfigParseEnvs("ENV_"), + NewConfigParseWorkdir("workdir"), + NewConfigParseHelp("help"), + } + // DefaultConfigEnvFiles 定义NewConfigParseEnvFile函数默认读取ENV文件。 + DefaultConfigEnvFiles = ".env" // DefaultContextMaxHandler 定义请求上下文handler数量上限,需要小于该值。 DefaultContextMaxHandler = 0xff - // DefaultContextFormMaxMemory 默认解析From body使用内存。 - DefaultContextFormMaxMemory int64 = 32 << 20 // 32 MB - // DefaultEmbedCacheControl 定义默认NewHandlerEmbedFunc使用的Cache-Control缓存策略 - DefaultEmbedCacheControl = "no-cache" - // DefaultEmbedTime 设置http返回embed文件的最后修改时间,默认为服务启动时间。 - // 如果服务存在多副本部署,通过设置相同的值保持多副本间的时间版本一致。 - DefaultEmbedTime = time.Now() - // DefaultFuncCreator 定义全局默认FuncCreator,RouetrCoreStd默认使用。 + // DefaultContextMaxApplicationFormSize 默认解析ApplicationFrom时body限制长度; + // 如果Body实现Limit() int64方法忽略该值。 + DefaultContextMaxApplicationFormSize int64 = 10 << 20 // 10M + // DefaultContextMaxMultipartFormMemory 默认解析MultipartFrom时body使用内存大小。 + DefaultContextMaxMultipartFormMemory int64 = 32 << 20 // 32 MB + // DefaultContextPushNotSupportedError 定义Context.Push时是否输出http.ErrNotSupported错误。 + DefaultContextPushNotSupportedError = true + // DefaultFuncCreator 定义全局默认FuncCreator, RouetrCoreStd默认使用。 DefaultFuncCreator = NewFuncCreator() - // DefaultGodocServer 定义应用默认使用的godoc服务器域名。 - DefaultGodocServer = "https://golang.org" - // DefaultHandlerExtend 为默认的函数扩展处理者,是RouterStd使用的最顶级的函数扩展处理者。 - DefaultHandlerExtend = NewHandlerExtendBase() - // DefaultHandlerExtendAllowType 定义handlerExtendBase允许使用的参数类型。 - DefaultHandlerExtendAllowType = map[reflect.Kind]struct{}{reflect.Func: {}, reflect.Interface: {}, - reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, reflect.Struct: {}} - // DefaultLoggerDepth 定义GetPanicStack函数默认显示栈最大层数。 - DefaultLoggerDepth = 64 - // DefaultLoggerLevelString 定义日志级别输出字符串。 - DefaultLoggerLevelString = [5]string{"DEBUG", "INFO", "WARNING", "ERROR", "FATAL"} + // DefaultHandlerBindFormTags 定义bind form使用tags。 + DefaultHandlerBindFormTags = []string{"form", "alias"} + // DefaultHandlerBindHeaderTags 定义bind header使用tags。 + DefaultHandlerBindHeaderTags = []string{"header", "alias"} + // DefaultHandlerBindURLTags 定义bind url使用tags。 + DefaultHandlerBindURLTags = []string{"url", "alias"} + // DefaultHandlerDataCode 定义Bind/Validate/Filter/Render返回错误时使用的自定义Code。 + DefaultHandlerDataCode = [4]int{} + // DefaultHandlerDataStatus 定义Bind/Validate/Filter/Render返回错误时使用的自定义Status。 + DefaultHandlerDataStatus = [4]int{} + // DefaultHandlerRenderFunc 定义默认使用的Render函数。 + DefaultHandlerRenderFunc = RenderJSON + // DefaultHandlerValidateTag 定义NewValidateField获取校验规则的结构体tag。 + DefaultHandlerValidateTag = "validate" + // DefaultHandlerEmbedCacheControl 定义默认NewHandlerEmbedFunc使用的Cache-Control缓存策略。 + DefaultHandlerEmbedCacheControl = "no-cache" + // DefaultHandlerEmbedTime 设置http返回embed文件的最后修改时间,默认为服务启动时间。 + // 如果服务存在多副本部署,通过设置相同的值使多副本间的时间版本一致,保证启用304缓存。 + DefaultHandlerEmbedTime = time.Now() + // DefaultHandlerExtender 为默认的函数扩展处理者。 + DefaultHandlerExtender = NewHandlerExtender() + // DefaultHandlerExtenderAllowType 定义handlerExtenderBase允许使用的参数类型。 + DefaultHandlerExtenderAllowType = map[reflect.Kind]struct{}{ + reflect.Func: {}, reflect.Interface: {}, + reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, reflect.Struct: {}, + } + // DefaultHandlerExtenderFuncs 定义NewHandlerExtender默认注册的扩展函数。 + DefaultHandlerExtenderFuncs = []any{ + NewHandlerEmbed, + NewHandlerFunc, + NewHandlerFuncContextError, + NewHandlerFuncContextAnyError, + NewHandlerFuncContextRender, + NewHandlerFuncContextRenderError, + NewHandlerFuncError, + NewHandlerFuncRPC, + NewHandlerFuncRPCMap, + NewHandlerFuncRender, + NewHandlerFuncRenderError, + NewHandlerFuncString, + NewHandlerHTTP, + NewHandlerHTTPFileSystem, + NewHandlerHTTPFunc1, + NewHandlerHTTPFunc2, + NewHandlerHTTPHandler, + NewHandlerStringer, + } + // DefaultLoggerDepthMaxStack 定义GetCallerStacks函数默认显示栈最大层数。 + DefaultLoggerDepthMaxStack = 0x4f // DefaultLoggerNull 定义空日志输出器。 - DefaultLoggerNull = NewLoggerNull() - // DefaultLoggerSyncDuration 定义LoggerStd默认Sync写入日志间隔时间,在Mount时使用。 - DefaultLoggerSyncDuration = time.Millisecond * 80 - // DefaultLoggerTimeFormat 定义默认日志输出和contextBase.WriteError的时间格式 - DefaultLoggerTimeFormat = "2006-01-02 15:04:05" - // DefaultRenderFunc 定义默认使用的Render函数。 - DefaultRenderFunc = RenderJSON - // DefaultRenderHTMLTemplate 定义RenderHTML的默认通用模板。 - DefaultRenderHTMLTemplate *template.Template - // DefaultRouterAllMethod 定义路由器允许注册的全部方法,注册其他方法别忽略,前六种方法始终存在。 - DefaultRouterAllMethod = []string{MethodGet, MethodPost, MethodPut, MethodDelete, MethodHead, MethodPatch, MethodOptions, MethodConnect, MethodTrace} + DefaultLoggerNull = NewLoggerNull() + DefaultLoggerEnableHookFatal = false + DefaultLoggerEnableHookMeta = false + DefaultLoggerEnableStdColor = true + // DefaultLoggerEntryBufferLength 定义默认LoggerEntry缓冲长度。 + DefaultLoggerEntryBufferLength = 2048 + // DefaultLoggerEntryFieldsLength 定义默认LoggerEntry Field数量。 + DefaultLoggerEntryFieldsLength = 4 + // DefaultLoggerFormatter 定义Logger默认日志格式化格式。 + DefaultLoggerFormatter = "json" + // DefaultLoggerFormatterFormatTime 定义默认日志输出和contextBase.WriteError的时间格式。 + DefaultLoggerFormatterFormatTime = "2006-01-02 15:04:05.000" + // DefaultLoggerFormatterKeyLevel 定义默认Level字段输出名称。 + DefaultLoggerFormatterKeyLevel = "level" + // DefaultLoggerFormatterKeyMessage 定义默认Message字段输出名称。 + DefaultLoggerFormatterKeyMessage = "message" + // DefaultLoggerFormatterKeyTime 定义默认Time字段输出名称。 + DefaultLoggerFormatterKeyTime = "time" + // DefaultLoggerLevelStrings 定义日志级别输出字符串。 + DefaultLoggerLevelStrings = [...]string{"DEBUG", "INFO", "WARNING", "ERROR", "FATAL", "DISCARD"} + // DefaultLoggerWriterRotateDataKeys 定义日期滚动时/天/月/年的关键字,顺序不可变化。 + DefaultLoggerWriterRotateDataKeys = [...]string{"hh", "dd", "mm", "yyyy"} + // DefaultLoggerWriterStdoutWindowsColor 定义GOOS=windows时是否使用彩色level字段。 + DefaultLoggerWriterStdoutWindowsColor = false + // DefaultRouterAllMethod 定义路由器允许注册的全部方法,前六种方法在RouterCore始终存在。 + DefaultRouterAllMethod = []string{ + MethodGet, MethodPost, MethodPut, + MethodDelete, MethodHead, MethodPatch, + MethodOptions, MethodConnect, MethodTrace, + } // DefaultRouterAnyMethod 定义Any方法的注册使用的方法。 - DefaultRouterAnyMethod = []string{MethodGet, MethodPost, MethodPut, MethodDelete, MethodHead, MethodPatch} + DefaultRouterAnyMethod = append([]string{}, DefaultRouterAllMethod[0:6]...) + // DefaultRouterCoreMethod 定义routerCoreStd实现中默认存储的6种方法处理对象。 + DefaultRouterCoreMethod = append([]string{}, DefaultRouterAllMethod[0:6]...) + // DefaultRouterLoggerKind 定义默认RouterStd输出那些类型日志。 + DefaultRouterLoggerKind = "all" + // DefaultServerListen 定义ServerListenConfig使用Listen函数,用于hook listen。 + DefaultServerListen = net.Listen + DefaultServerReadTimeout = 60 * time.Second + DefaultServerReadHeaderTimeout = 60 * time.Second + DefaultServerWriteTimeout = 60 * time.Second + DefaultServerIdleTimeout = 60 * time.Second // DefaultServerShutdownWait 定义Server优雅退出等待时间。 DefaultServerShutdownWait = 30 * time.Second + // DefaultTemplateNameStaticIndex 定义默认渲染静态目录模板名称。 + DefaultTemplateNameStaticIndex = "eudore-embed-index" + // DefaultTemplateNameRenderData 定义默认RenderHTML模板名称。 + DefaultTemplateNameRenderData = "eudore-render-data" + // DefaultTemplateContentStaticIndex 定义默认渲染静态目录模板内容。 + DefaultTemplateContentStaticIndex = templateEmbedIndex + // DefaultTemplateContentRenderData 定义默认RenderHTML模板内容。 + DefaultTemplateContentRenderData = tempdateRenderData + // DefaultTemplateInit 定义App默认加载模板内容。 + DefaultTemplateInit = fmt.Sprintf(`{{- define "%s" -}}%s{{- end -}}{{- define "%s" -}}%s{{- end -}}`, + DefaultTemplateNameStaticIndex, DefaultTemplateContentStaticIndex, + DefaultTemplateNameRenderData, DefaultTemplateContentRenderData, + ) + // DefaultValueGetSetTags 定义Get/SetAny默认的tag。 + DefaultValueGetSetTags = []string{"alias"} + // DefaultValueParseTimeFormats 定义尝试解析的时间格式。 + DefaultValueParseTimeFormats = []string{ + "2006-01-02", + "20060102", + "15:04:05", + "2006-01-02 15:04:05", + "2006-01-02T15:04:05Z07:00", + "2006-01-02T15:04:05.999999999Z07:00", + time.ANSIC, + time.UnixDate, + time.RubyDate, + time.RFC822, + time.RFC822Z, + time.RFC850, + time.RFC1123, + time.RFC1123Z, + time.RFC3339, + time.RFC3339Nano, + } + // DefaultValueParseTimeFixed 定义预定义时间格式长度是否固定,避免解析长度不相同的时间格式。 + DefaultValueParseTimeFixed = []bool{ + true, true, true, true, true, true, + false, false, true, true, true, true, true, true, true, true, + } + // DefaultDaemonPidfile 定义daemon默认使用的pid文件。 + DefaultDaemonPidfile = "/var/run/eudore.pid" + // DefaultGodocServer 定义应用默认使用的godoc服务器域名。 + DefaultGodocServer = "https://golang.org" // DefaultTraceServer 定义应用默认使用的jaeger链路追踪服务器域名。 DefaultTraceServer = "" - // DefaultValidateTag 定义NewValidateField获取校验规则的结构体tag。 - DefaultValidateTag = "validate" - // defaultRouterAnyMethod 定义routerStd默认存储的6种方法处理对象。 - defaultRouterAnyMethod = []string{MethodGet, MethodPost, MethodPut, MethodDelete, MethodHead, MethodPatch} - - // DefaultGetSetTags 定义Get/Set函数使用的默认tag。 - DefaultGetSetTags = []string{"alias"} - // DefaultConvertTags 定义默认转换使用的结构体tags。 - DefaultConvertTags = []string{"alias"} - - // 定义HandlerData返回error使用的Status和Code - - // StatucBindFail 定义Bind返回错误的状态码。 - StatucBindFail = 0 - // StatucValidateFail 定义Validate返回错误的状态码。 - StatucValidateFail = 0 - // StatucFilteFail 定义Filte返回错误的状态码。 - StatucFilteFail = 0 - // StatucRenderFail 定义Render返回错误的状态码。 - StatucRenderFail = 0 - // CodeBindFail 定义Bind返回错误的Code。 - CodeBindFail = 0 - // CodeValidateFail 定义Validate返回错误的Code。 - CodeValidateFail = 0 - // CodeFilteFail 定义Filte返回错误的Code。 - CodeFilteFail = 0 - // CodeRenderFail 定义Render返回错误的Code。 - CodeRenderFail = 0 -) - -// 定义各种类型的反射类型。 -var ( - typeBool = reflect.TypeOf((*bool)(nil)).Elem() - typeBytes = reflect.TypeOf((*[]byte)(nil)).Elem() - typeError = reflect.TypeOf((*error)(nil)).Elem() - typeInterface = reflect.TypeOf((*interface{})(nil)).Elem() - typeString = reflect.TypeOf((*string)(nil)).Elem() - - typeContext = reflect.TypeOf((*Context)(nil)).Elem() - typeController = reflect.TypeOf((*Controller)(nil)).Elem() - typeControllerName = reflect.TypeOf((*controllerName)(nil)).Elem() - typeHandlerFunc = reflect.TypeOf((*HandlerFunc)(nil)).Elem() - typeStringer = reflect.TypeOf((*fmt.Stringer)(nil)).Elem() - typeTimeTime = reflect.TypeOf((*time.Time)(nil)).Elem() -) - -// 检测各类接口 -var ( - _ Logger = (*LoggerStd)(nil) - _ LoggerStdData = (*loggerStdDataJSON)(nil) - _ LoggerStdData = (*loggerStdDataInit)(nil) - _ Config = (*configStd)(nil) - _ Client = (*clientStd)(nil) - _ Server = (*serverStd)(nil) - _ Server = (*serverFcgi)(nil) - _ Router = (*RouterStd)(nil) - _ RouterCore = (*routerCoreStd)(nil) - _ RouterCore = (*routerCoreDebug)(nil) - _ RouterCore = (*routerCoreHost)(nil) - _ RouterCore = (*routerCoreLock)(nil) - _ Context = (*contextBase)(nil) - _ ResponseWriter = (*responseWriterHTTP)(nil) - _ Controller = (*ControllerAutoRoute)(nil) - _ Controller = (*controllerError)(nil) - _ HandlerExtender = (*handlerExtendBase)(nil) - _ HandlerExtender = (*handlerExtendWarp)(nil) - _ HandlerExtender = (*handlerExtendTree)(nil) - _ FuncCreator = (*funcCreator)(nil) -) -// 定义默认错误 -var ( - // ErrConverterInputDataNil 在Converter方法时,输出参数是空。 - ErrConverterInputDataNil = errors.New("Converter input value is nil") - // ErrConverterInputDataNotPtr 在Converter方法时,输出参数是空。 - ErrConverterInputDataNotPtr = errors.New("Converter input value not is ptr") - // ErrConverterTargetDataNil 在Converter方法时,目标参数是空。 - ErrConverterTargetDataNil = errors.New("Converter target data is nil") + // ErrClientBodyFormNotGetBody 定义ClientBodyForm无法获取复制对象错误。 + ErrClientBodyFormNotGetBody = errors.New("client bodyForm contains files that cannot be copied, cannot copy body") + // ErrFuncCreatorNotFunc 定义FuncCreator无法获取或创建函数。 + ErrFuncCreatorNotFunc = errors.New("not found or create func") + // ErrHandlerExtenderParamNotFunc 定义调用RegisterHandlerExtender函数时,参数必须是一个函数。 + ErrHandlerExtenderParamNotFunc = errors.New("the parameter type of RegisterNewHandler must be a function") // ErrLoggerLevelUnmarshalText 日志级别解码错误,请检查输出的[]byte是否有效。 ErrLoggerLevelUnmarshalText = errors.New("logger level UnmarshalText error") - ErrRenderHandlerSkip = errors.New("render hander skip") - // ErrRegisterNewHandlerParamNotFunc 调用RegisterHandlerExtend函数时,参数必须是一个函数。 - ErrRegisterNewHandlerParamNotFunc = errors.New("The parameter type of RegisterNewHandler must be a function") - // ErrResponseWriterHTTPNotHijacker ResponseWriterHTTP对象没有实现http.Hijacker接口。 - ErrResponseWriterHTTPNotHijacker = errors.New("http.Hijacker interface is not supported") - // ErrSeterNotSupportField Seter对象不支持设置当前属性。 - ErrSeterNotSupportField = errors.New("Converter seter not support set field") - // ErrMiddlewareRequestEntityTooLarge middleware/BodyLimit 分段请求body读取时限制长队返回错误。 - ErrMiddlewareRequestEntityTooLarge = errors.New("Request Entity Too Large") + // ErrRenderHandlerSkip 定义Renders执行Render时无法渲染,跳过当前Render。 + ErrRenderHandlerSkip = errors.New("render hander skip") + // ErrResponseWriterNotHijacker ResponseWriterHTTP对象没有实现http.Hijacker接口。 + ErrResponseWriterNotHijacker = errors.New("http.Hijacker interface is not supported") + // ErrValueInputDataNil 在Converter方法时,输出参数是空。 + ErrValueInputDataNil = errors.New("converter input value is nil") + // ErrValueInputDataNotPtr 在Converter方法时,输出参数是空。 + ErrValueInputDataNotPtr = errors.New("converter input value not is ptr") // ErrFormatBindDefaultNotSupportContentType BindDefault函数不支持当前的Content-Type Header。 - ErrFormatBindDefaultNotSupportContentType = "BindDefault not support content type header: %s" - // ErrFormatConverterGet 在Get方法路径查找返回错误。 - ErrFormatConverterGet = "Get path '%s' error: %s" - // ErrFormatConverterNotCanset 在Set方法时,结构体不支持该项属性。 - ErrFormatConverterNotCanset = "The attribute '%s' of structure %s is not set, please use public field" - // ErrFormatConverterSetArrayIndexInvalid 在Set方法时,设置数组的索引的无效 - ErrFormatConverterSetArrayIndexInvalid = "the Set function obtained array index '%s' is invalid, array len is %d" - // ErrFormatConverterSetStringUnknownType setWithString函数遇到未定义的反射类型 - ErrFormatConverterSetStringUnknownType = "setWithString unknown type %s" - // ErrFormatConverterSetStructNotField 在Set时,结构体没有当前属性。 - ErrFormatConverterSetStructNotField = "Setting the structure has no attribute '%s', or this attribute is not exportable" - // ErrFormatConverterSetTypeError 在Set时,类型异常,无法继续设置值。 - ErrFormatConverterSetTypeError = "The type of the set value is %s, which is not configurable, key: %v, val: %s" - // ErrFormatConverterSetWithValue setWithValue函数中类型无法赋值。 - ErrFormatConverterSetWithValue = "The setWithValue method type %s cannot be assigned to type %s" - // ErrFormatRegisterHandlerExtendInputParamError RegisterHandlerExtend函数注册的函数参数错误。 - ErrFormatRegisterHandlerExtendInputParamError = "The '%s' input parameter is illegal and should be one func/interface/ptr/struct" - // ErrFormatRegisterHandlerExtendOutputParamError RegisterHandlerExtend函数注册的函数返回值错误。 - ErrFormatRegisterHandlerExtendOutputParamError = "The '%s' output parameter is illegal and should be a HandlerFunc object" - // ErrFormatRouterStdAddController RouterStd控制器路由注入错误 - ErrFormatRouterStdAddController = "The RouterStd.AddController Inject %s error: %v" - // ErrFormatRouterStdAddHandlerExtend RouterStd添加扩展处理函数错误 - ErrFormatRouterStdAddHandlerExtend = "The RouterStd.AddHandlerExtend path is '%s' RegisterHandlerExtend error: %v" - // ErrFormatRouterStdRegisterHandlersMethodInvalid RouterStd.registerHandlers 的添加的是无效的,全部有效方法为RouterAnyMethod。 - ErrFormatRouterStdRegisterHandlersMethodInvalid = "The RouterStd.registerHandlers arg method '%s' is invalid, complete method: '%s', add fullpath: '%s'" - // ErrFormatRouterStdRegisterHandlersRecover RouterStd注册路由时恢复panic。 - ErrFormatRouterStdRegisterHandlersRecover = "The RouterStd.registerHandlers arg method is '%s' and path is '%s', recover error: %v" - // ErrFormatRouterStdNewHandlerFuncsUnregisterType RouterStd添加处理对象或中间件的第n个参数类型未注册,需要先使用RegisterHandlerExtend或AddHandlerExtend注册该函数类型。 - ErrFormatRouterStdNewHandlerFuncsUnregisterType = "The RouterStd.newHandlerFuncs path is '%s', %dth handler parameter type is '%s', this is the unregistered handler type" - // ErrFormatProtobufDecodeNilInteface 定义protobuf解码到空接口 + ErrFormatBindDefaultNotSupportContentType = "BindDefault: not support content type header: %s" + // ErrFormatClintCheckStatusError 定义Client检查status不匹配错误。 + ErrFormatClintCheckStatusError = "clint check status is %d not in %v" + // ErrFormatClintParseBodyError 定义Client解析Body时无法解析Content-Type错误。 + ErrFormatClintParseBodyError = "eudore client parse not suppert Content-Type: %s" + // ErrFormatContextParseFormNotSupportContentType Context解析Form时时,不支持Content-Type。 + ErrFormatContextParseFormNotSupportContentType = "eudore.Context: parse form not supported Content-Type: %s" + // ErrFormatContextRedirectInvalid Context.Redirect方法使用了无效的状态码。 + ErrFormatContextRedirectInvalid = "eudore.Context: invalid redirect status code %d" + // ErrFormatContextPushFailed Context.Push方法推送资源错误。 + ErrFormatContextPushFailed = "eudore.Context: push resource %s failed: %w" + // ErrFormatFuncCreatorRegisterInvalidType fc注册函数类似是无效的。 + ErrFormatFuncCreatorRegisterInvalidType = "Register func '%s' type is %T, must 'func(T) bool' or 'func(string) (func(T) bool, error)'" + // ErrFormatHandlerExtenderInputParamError RegisterHandlerExtender函数注册的函数参数错误。 + ErrFormatHandlerExtenderInputParamError = "the '%s' input parameter is illegal and should be one func/interface/ptr/struct" + // ErrFormatHandlerExtenderOutputParamError RegisterHandlerExtender函数注册的函数返回值错误。 + ErrFormatHandlerExtenderOutputParamError = "the '%s' output parameter is illegal and should be a HandlerFunc object" + // ErrFormatRouterStdAddController RouterStd控制器路由注入错误。 + ErrFormatRouterStdAddController = "the RouterStd.AddController Inject %s error: %w" + // ErrFormatRouterStdAddHandlerExtender RouterStd添加扩展处理函数错误。 + ErrFormatRouterStdAddHandlerExtender = "the RouterStd.AddHandlerExtender path is '%s' RegisterHandlerExtender error: %w" + // ErrFormatRouterStdaddHandlerMethodInvalid RouterStd.addHandler 的添加的是无效的,全部有效方法为RouterAnyMethod。 + ErrFormatRouterStdAddHandlerMethodInvalid = "the RouterStd.addHandler arg method '%s' is invalid, add fullpath: '%s'" + // ErrFormatRouterStdaddHandlerRecover RouterStd注册路由时恢复panic。 + ErrFormatRouterStdAddHandlerRecover = "the RouterStd.addHandler arg method is '%s' and path is '%s', recover error: %w" + // ErrFormarRouterStdLoadInvalidFunc RouterStd无法加载路径对应的校验函数。 + ErrFormarRouterStdLoadInvalidFunc = "loadCheckFunc path is invalid, load path '%s' error: %v " + // ErrFormatRouterStdNewHandlerFuncsUnregisterType RouterStd添加处理对象或中间件的第n个参数类型未注册,需要先使用RegisterHandlerExtender或AddHandlerExtender注册该函数类型。 + ErrFormatRouterStdNewHandlerFuncsUnregisterType = "the RouterStd.newHandlerFuncs path is '%s', %dth handler parameter type is '%s', this is the unregistered handler type" + // ErrFormatProtobufDecodeNilInteface 定义protobuf解码到空接口。 ErrFormatProtobufDecodeNilInteface = "protobuf decode %s interface %s is nil" ErrFormatProtobufDecodeInvalidFlag = "protobuf decode %s invalid flag %d" ErrFormatProtobufDecodeInvalidKind = "protobuf decode %s invalid kind %s" @@ -235,86 +354,90 @@ var ( ErrFormatProtobufDecodeReadInvalid = "protobuf decode %s read length %d invalid has data %d" ErrFormatProtobufDecodeMessageNotRead = "protobuf decode message has %d not read" ErrFormatProtobufTypeMustSturct = "protobuf encdoe/decode kind must struct, current type %s" - // ErrFormatMiddlewareRequestEntityTooLargeSzie BodyLimit请求长度超过限制。 - ErrFormatMiddlewareRequestEntityTooLargeSzie = "Request Entity Too Large, limit body size %d" - // ErrFormarRouterStdLoadInvalidFunc RouterStd无法加载路径对应的校验函数。 - ErrFormarRouterStdLoadInvalidFunc = "loadCheckFunc path is invalid, load path '%s' error: %v " - // ErrFormatParseValidateFieldError Validate解析结构体规则错误。 - ErrFormatParseValidateFieldError = "validateField %s.%s parse field %s create rule %s error: %s" - // ErrFormatFuncCreatorRegisterInvalidType fc注册函数类似是无效的。 - ErrFormatFuncCreatorRegisterInvalidType = "Register func %s type is %T, must 'func(T) bool' or 'func(string) (func(T) bool, error)'" - // ErrFormatFuncCreatorNotFunc 无法创建对应的校验函数。 - ErrFormatFuncCreatorNotFunc = "not found or create func %s" + // ErrFormatParseValidateFieldError 定义Validate校验失败时输出Error格式。 + ErrFormatValidateErrorFormat = "validate %s.%s field %s check %s rule fatal, value: %%#v" + // ErrFormatValidateParseFieldError Validate解析结构体规则错误。 + ErrFormatValidateParseFieldError = "validateField %s.%s parse field %s create rule %s error: %s" + // ErrFormatValueError 定义Value操作错误。 + ErrFormatValueError = "value %s path '%s' error: %w" + // ErrFormatValueTypeNil 定义Value对象为空。 + ErrFormatValueTypeNil = "is nil" + ErrFormatValueAnonymousField = " is anonymous field" + ErrFormatValueNotField = "not found field '%s'" + ErrFormatValueArrayIndexInvalid = "parse index '%s' is invalid, length is %d" + ErrFormatValueMapIndexInvalid = "parse index '%s' is invalid" + ErrFormatValueMapValueInvalid = "obtained index '%s' value is invalid" + ErrFormatValueStructUnexported = "field '%s' is unexported" + ErrFormatValueStructNotCanset = "field '%s' is not canset " + // ErrFormatConverterSetStringUnknownType setWithString函数遇到未定义的反射类型。 + ErrFormatValueSetStringUnknownType = "setWithString unknown type %s" + // ErrFormatConverterSetWithValue setWithValue函数中类型无法赋值。 + ErrFormatValueSetWithValue = "the setWithValue method type %s cannot be assigned to type %s" ) // 定义eudore定义各种常量。 const ( - // Eudore environ - - // EnvEudoreIsDaemon 用于表示是否fork后台启动。 - EnvEudoreIsDaemon = "EUDORE_IS_DEAMON" - // EnvEudoreIsNotify 表示使用使用了Notify组件。 - EnvEudoreIsNotify = "EUDORE_IS_NOTIFY" - // EnvEudoreDisablePidfile 用于Command组件不写入pidfile,Notify组件启动的子程序不写入pidfile。 - EnvEudoreDisablePidfile = "EUDORE_DISABLE_PIDFILE" + // EnvEudoreListeners 定义启动fd的地址。 + EnvEudoreDaemonListeners = "EUDORE_DAEMON_LISTENERS" + // EnvEudoreDaemonRestartID 定义重启时父进程的pid,由子进程kill。 + EnvEudoreDaemonRestartID = "EUDORE_DAEMON_RESTART_ID" + // EnvEudoreDaemonEnable 用于表示是否fork后台启动,会禁用Logger stdout输出。 + EnvEudoreDaemonEnable = "EUDORE_DAEMON_ENABLE" + // EnvEudoreDaemonTimeout 定义daemon等待restart和stop命令完成的超时秒数。 + EnvEudoreDaemonTimeout = "EUDORE_DAEMON_TIMEOUT" - // Response statue - - StatusContinue = 100 // RFC 7231, 6.2.1 - StatusSwitchingProtocols = 101 // RFC 7231, 6.2.2 - StatusProcessing = 102 // RFC 2518, 10.1 - - StatusOK = 200 // RFC 7231, 6.3.1 - StatusCreated = 201 // RFC 7231, 6.3.2 - StatusAccepted = 202 // RFC 7231, 6.3.3 - StatusNonAuthoritativeInfo = 203 // RFC 7231, 6.3.4 - StatusNoContent = 204 // RFC 7231, 6.3.5 - StatusResetContent = 205 // RFC 7231, 6.3.6 - StatusPartialContent = 206 // RFC 7233, 4.1 - StatusMultiStatus = 207 // RFC 4918, 11.1 - StatusAlreadyReported = 208 // RFC 5842, 7.1 - StatusIMUsed = 226 // RFC 3229, 10.4.1 - - StatusMultipleChoices = 300 // RFC 7231, 6.4.1 - StatusMovedPermanently = 301 // RFC 7231, 6.4.2 - StatusFound = 302 // RFC 7231, 6.4.3 - StatusSeeOther = 303 // RFC 7231, 6.4.4 - StatusNotModified = 304 // RFC 7232, 4.1 - StatusUseProxy = 305 // RFC 7231, 6.4.5 - - StatusTemporaryRedirect = 307 // RFC 7231, 6.4.7 - StatusPermanentRedirect = 308 // RFC 7538, 3 - - StatusBadRequest = 400 // RFC 7231, 6.5.1 - StatusUnauthorized = 401 // RFC 7235, 3.1 - StatusPaymentRequired = 402 // RFC 7231, 6.5.2 - StatusForbidden = 403 // RFC 7231, 6.5.3 - StatusNotFound = 404 // RFC 7231, 6.5.4 - StatusMethodNotAllowed = 405 // RFC 7231, 6.5.5 - StatusNotAcceptable = 406 // RFC 7231, 6.5.6 - StatusProxyAuthRequired = 407 // RFC 7235, 3.2 - StatusRequestTimeout = 408 // RFC 7231, 6.5.7 - StatusConflict = 409 // RFC 7231, 6.5.8 - StatusGone = 410 // RFC 7231, 6.5.9 - StatusLengthRequired = 411 // RFC 7231, 6.5.10 - StatusPreconditionFailed = 412 // RFC 7232, 4.2 - StatusRequestEntityTooLarge = 413 // RFC 7231, 6.5.11 - StatusRequestURITooLong = 414 // RFC 7231, 6.5.12 - StatusUnsupportedMediaType = 415 // RFC 7231, 6.5.13 - StatusRequestedRangeNotSatisfiable = 416 // RFC 7233, 4.4 - StatusExpectationFailed = 417 // RFC 7231, 6.5.14 - StatusTeapot = 418 // RFC 7168, 2.3.3 - StatusMisdirectedRequest = 421 // RFC 7540, 9.1.2 - StatusUnprocessableEntity = 422 // RFC 4918, 11.2 - StatusLocked = 423 // RFC 4918, 11.3 - StatusFailedDependency = 424 // RFC 4918, 11.4 - StatusTooEarly = 425 // RFC 8470, 5.2. - StatusUpgradeRequired = 426 // RFC 7231, 6.5.15 - StatusPreconditionRequired = 428 // RFC 6585, 3 - StatusTooManyRequests = 429 // RFC 6585, 4 - StatusRequestHeaderFieldsTooLarge = 431 // RFC 6585, 5 - StatusUnavailableForLegalReasons = 451 // RFC 7725, 3 + // Status. + StatusContinue = 100 // RFC 7231, 6.2.1 + StatusSwitchingProtocols = 101 // RFC 7231, 6.2.2 + StatusProcessing = 102 // RFC 2518, 10.1 + StatusOK = 200 // RFC 7231, 6.3.1 + StatusCreated = 201 // RFC 7231, 6.3.2 + StatusAccepted = 202 // RFC 7231, 6.3.3 + StatusNonAuthoritativeInfo = 203 // RFC 7231, 6.3.4 + StatusNoContent = 204 // RFC 7231, 6.3.5 + StatusResetContent = 205 // RFC 7231, 6.3.6 + StatusPartialContent = 206 // RFC 7233, 4.1 + StatusMultiStatus = 207 // RFC 4918, 11.1 + StatusAlreadyReported = 208 // RFC 5842, 7.1 + StatusIMUsed = 226 // RFC 3229, 10.4.1 + StatusMultipleChoices = 300 // RFC 7231, 6.4.1 + StatusMovedPermanently = 301 // RFC 7231, 6.4.2 + StatusFound = 302 // RFC 7231, 6.4.3 + StatusSeeOther = 303 // RFC 7231, 6.4.4 + StatusNotModified = 304 // RFC 7232, 4.1 + StatusUseProxy = 305 // RFC 7231, 6.4.5 + StatusTemporaryRedirect = 307 // RFC 7231, 6.4.7 + StatusPermanentRedirect = 308 // RFC 7538, 3 + StatusBadRequest = 400 // RFC 7231, 6.5.1 + StatusUnauthorized = 401 // RFC 7235, 3.1 + StatusPaymentRequired = 402 // RFC 7231, 6.5.2 + StatusForbidden = 403 // RFC 7231, 6.5.3 + StatusNotFound = 404 // RFC 7231, 6.5.4 + StatusMethodNotAllowed = 405 // RFC 7231, 6.5.5 + StatusNotAcceptable = 406 // RFC 7231, 6.5.6 + StatusProxyAuthRequired = 407 // RFC 7235, 3.2 + StatusRequestTimeout = 408 // RFC 7231, 6.5.7 + StatusConflict = 409 // RFC 7231, 6.5.8 + StatusGone = 410 // RFC 7231, 6.5.9 + StatusLengthRequired = 411 // RFC 7231, 6.5.10 + StatusPreconditionFailed = 412 // RFC 7232, 4.2 + StatusRequestEntityTooLarge = 413 // RFC 7231, 6.5.11 + StatusRequestURITooLong = 414 // RFC 7231, 6.5.12 + StatusUnsupportedMediaType = 415 // RFC 7231, 6.5.13 + StatusRequestedRangeNotSatisfiable = 416 // RFC 7233, 4.4 + StatusExpectationFailed = 417 // RFC 7231, 6.5.14 + StatusTeapot = 418 // RFC 7168, 2.3.3 + StatusMisdirectedRequest = 421 // RFC 7540, 9.1.2 + StatusUnprocessableEntity = 422 // RFC 4918, 11.2 + StatusLocked = 423 // RFC 4918, 11.3 + StatusFailedDependency = 424 // RFC 4918, 11.4 + StatusTooEarly = 425 // RFC 8470, 5.2. + StatusUpgradeRequired = 426 // RFC 7231, 6.5.15 + StatusPreconditionRequired = 428 // RFC 6585, 3 + StatusTooManyRequests = 429 // RFC 6585, 4 + StatusRequestHeaderFieldsTooLarge = 431 // RFC 6585, 5 + StatusUnavailableForLegalReasons = 451 // RFC 7725, 3 StatusInternalServerError = 500 // RFC 7231, 6.6.1 StatusNotImplemented = 501 // RFC 7231, 6.6.2 StatusBadGateway = 502 // RFC 7231, 6.6.3 @@ -326,8 +449,9 @@ const ( StatusLoopDetected = 508 // RFC 5842, 7.2 StatusNotExtended = 510 // RFC 2774, 7 StatusNetworkAuthenticationRequired = 511 // RFC 6585, 6 + StauusClientClosedRequest = 499 // nginx - // Header + // Header. HeaderAccept = "Accept" HeaderAcceptCharset = "Accept-Charset" @@ -384,6 +508,7 @@ const ( HeaderPragma = "Pragma" HeaderProxyAuthenticate = "Proxy-Authenticate" HeaderProxyAuthorization = "Proxy-Authorization" + HeaderProxyConnection = "Proxy-Connection" HeaderPublicKeyPins = "Public-Key-Pins" HeaderPublicKeyPinsReportOnly = "Public-Key-Pins-Report-Only" HeaderRange = "Range" @@ -420,9 +545,10 @@ const ( HeaderXRequestID = "X-Request-Id" HeaderXTraceID = "X-Trace-Id" HeaderXEudoreAdmin = "X-Eudore-Admin" + HeaderXEudoreCache = "X-Eudore-Cache" HeaderXEudoreRoute = "X-Eudore-Route" - // default http method by rfc2616 + // default http method by rfc2616. MethodAny = "ANY" MethodGet = "GET" @@ -435,44 +561,178 @@ const ( MethodConnect = "CONNECT" MethodTrace = "TRACE" - // Mime + // Mime. - MimeCharsetUtf8 = "charset=utf-8" - MimeMultipartForm = "multipart/form-data" MimeText = "text/*" MimeTextPlain = "text/plain" - MimeTextPlainCharsetUtf8 = MimeTextPlain + "; " + MimeCharsetUtf8 MimeTextMarkdown = "text/markdown" - MimeTextMarkdownCharsetUtf8 = MimeTextMarkdown + "; " + MimeCharsetUtf8 MimeTextJavascript = "text/javascript" - MimeTextJavascriptCharsetUtf8 = MimeTextJavascript + "; " + MimeCharsetUtf8 MimeTextHTML = "text/html" - MimeTextHTMLCharsetUtf8 = MimeTextHTML + "; " + MimeCharsetUtf8 MimeTextCSS = "text/css" - MimeTextCSSCharsetUtf8 = MimeTextCSS + "; " + MimeCharsetUtf8 MimeTextXML = "text/xml" - MimeTextXMLCharsetUtf8 = MimeTextXML + "; " + MimeCharsetUtf8 - MimeApplicationYAMLCharsetUtf8 = MimeApplicationYAML + "; " + MimeCharsetUtf8 MimeApplicationYAML = "application/yaml" - MimeApplicationXMLCharsetUtf8 = MimeApplicationXML + "; " + MimeCharsetUtf8 MimeApplicationXML = "application/xml" MimeApplicationProtobuf = "application/protobuf" - MimeApplicationJSONCharsetUtf8 = MimeApplicationJSON + "; " + MimeCharsetUtf8 MimeApplicationJSON = "application/json" - MimeApplicationFormCharsetUtf8 = MimeApplicationForm + "; " + MimeCharsetUtf8 MimeApplicationForm = "application/x-www-form-urlencoded" - - // Param + MimeApplicationOctetStream = "application/octet-stream" + MimeMultipartForm = "multipart/form-data" + MimeMultipartMixed = "multipart/mixed" + MimeCharsetUtf8 = "charset=utf-8" + MimeTextPlainCharsetUtf8 = MimeTextPlain + "; " + MimeCharsetUtf8 + MimeTextMarkdownCharsetUtf8 = MimeTextMarkdown + "; " + MimeCharsetUtf8 + MimeTextJavascriptCharsetUtf8 = MimeTextJavascript + "; " + MimeCharsetUtf8 + MimeTextHTMLCharsetUtf8 = MimeTextHTML + "; " + MimeCharsetUtf8 + MimeTextCSSCharsetUtf8 = MimeTextCSS + "; " + MimeCharsetUtf8 + MimeTextXMLCharsetUtf8 = MimeTextXML + "; " + MimeCharsetUtf8 + MimeApplicationYAMLCharsetUtf8 = MimeApplicationYAML + "; " + MimeCharsetUtf8 + MimeApplicationXMLCharsetUtf8 = MimeApplicationXML + "; " + MimeCharsetUtf8 + MimeApplicationJSONCharsetUtf8 = MimeApplicationJSON + "; " + MimeCharsetUtf8 + MimeApplicationFormCharsetUtf8 = MimeApplicationForm + "; " + MimeCharsetUtf8 + // Param. ParamAction = "action" ParamAllow = "allow" + ParamAutoIndex = "autoindex" ParamBasicAuth = "basicauth" ParamCaller = "caller" ParamControllerGroup = "controllergroup" + ParamDepth = "depth" + ParamLoggerKind = "loggerkind" + ParamPrefix = "prefix" ParamRegister = "register" ParamTemplate = "template" ParamRoute = "route" ParamUserid = "Userid" + ParamUsername = "Username" ParamPolicy = "Policy" ParamResource = "Resource" ) + +var ( + templateEmbedIndex = ` + + + + + Index of {{.Path}} + + +

Index of {{.Path}} {{if .Upload}}{{end}}

+{{- if ne .Path "/"}}{{end}} + + + + {{- range $index, $file := .Files}}{{if $file.IsDir}} + + {{- else }} + + {{- end }}{{end}} + +
NameSizeDate Modified
{{$file.Name}}/{{$file.ModTime}}
{{$file.Name}}{{$file.SizeFormat}}{{$file.ModTime}}
` + tempdateRenderData = ` + + + Eudore Render + + + + + + + +
+
+ General +
Request URL: {{.Host}}{{.Path}}
+
Request Method: {{.Method}}
+
Status Code: {{.Status}}
+
Remote Address: {{.RemoteAddr}}
+
Local Address: {{.LocalAddr}}
+
+ {{- if ne (len .Query) 0 }} +
+ Requesst Querys + {{- range $key, $vals := .Query -}} + {{- range $i, $val := $vals }} +
{{$key}}: {{$val}}
+ {{- end }} + {{- end }} +
+ {{- end }} +
+ Requesst Params + {{- $iskey := true }} + {{- range $i,$val := .Params}} + {{- if $iskey}} +
{{$val}}: {{- else}}{{$val}}
{{end}} + {{- $iskey = not $iskey}} + {{- end}} +
+
+ Request Headers + {{- range $key, $vals := .RequestHeader -}} + {{- range $i, $val := $vals }} +
{{$key}}: {{$val}}
+ {{- end }} + {{- end }} +
+
+ {{- $trace := .TraceServer }} + Response Headers + {{- range $key, $vals := .ResponseHeader -}} + {{- range $i, $val := $vals }} + {{- if and (eq $key "X-Trace-Id") (ne $trace "")}} +
{{$key}}: {{$val}}
+ {{- else }} +
{{$key}}: {{$val}}
+ {{- end }} + {{- end }} + {{- end }} +
+
+ Response Data +
{{.Data}}
+
+
+ +` +) diff --git a/context.go b/context.go index 35d4457..eafc669 100644 --- a/context.go +++ b/context.go @@ -8,11 +8,14 @@ import ( "context" "errors" "fmt" - "io/ioutil" + "io" + "mime" "mime/multipart" "net" "net/http" + "net/textproto" "net/url" + "reflect" "strings" "sync" "time" @@ -20,6 +23,7 @@ import ( /* Context 定义请求上下文接口,分为请求上下文数据、请求、参数、响应、日志输出五部分。 + context.Context、eudore.ResponseWriter、*http.Request、eudore.Logger对象读写 中间件机制执行 基本请求信息 @@ -36,11 +40,11 @@ type Context interface { GetContext() context.Context Request() *http.Request Response() ResponseWriter - Value(interface{}) interface{} + Value(any) any SetContext(context.Context) SetRequest(*http.Request) SetResponse(ResponseWriter) - SetValue(interface{}, interface{}) + SetValue(any, any) // handles SetHandler(int, HandlerFuncs) GetHandler() (int, HandlerFuncs) @@ -58,7 +62,7 @@ type Context interface { ContentType() string Istls() bool Body() []byte - Bind(interface{}) error + Bind(any) error // param query header cookie form Params() *Params @@ -70,7 +74,7 @@ type Context interface { SetHeader(string, string) Cookies() []Cookie GetCookie(string) string - SetCookie(*SetCookie) + SetCookie(*CookieSet) SetCookieValue(string, string, int) FormValue(string) string FormValues() map[string][]string @@ -79,38 +83,38 @@ type Context interface { // response Write([]byte) (int, error) + WriteString(string) (int, error) WriteHeader(int) Redirect(int, string) Push(string, *http.PushOptions) error - Render(interface{}) error - WriteString(string) error - WriteFile(string) error + Render(any) error + WriteFile(string) // Database interface - Query(interface{}, DatabaseStmt) error + Query(any, DatabaseStmt) error Exec(DatabaseStmt) error - NewRequest(string, string, ...interface{}) error + NewRequest(string, string, ...any) error // Logger interface - Debug(...interface{}) - Info(...interface{}) - Warning(...interface{}) - Error(...interface{}) - Fatal(...interface{}) - Debugf(string, ...interface{}) - Infof(string, ...interface{}) - Warningf(string, ...interface{}) - Errorf(string, ...interface{}) - Fatalf(string, ...interface{}) - WithField(string, interface{}) Logger - WithFields([]string, []interface{}) Logger + Debug(...any) + Info(...any) + Warning(...any) + Error(...any) + Fatal(...any) + Debugf(string, ...any) + Infof(string, ...any) + Warningf(string, ...any) + Errorf(string, ...any) + Fatalf(string, ...any) + WithField(string, any) Logger + WithFields([]string, []any) Logger } // contextBase 实现Context接口。 type contextBase struct { // context index int - handler HandlerFuncs - httpParams Params + handlers HandlerFuncs + params Params config *contextBaseConfig RequestReader *http.Request ResponseWriter ResponseWriter @@ -123,22 +127,23 @@ type contextBase struct { } type contextBaseConfig struct { - Logger Logger - Database Database - Client Client - Bind func(Context, interface{}) error - Validate func(Context, interface{}) error - Filte func(Context, interface{}) error - Render func(Context, interface{}) error + Logger Logger + Database Database + Client Client + Bind func(Context, any) error + Validater func(Context, any) error + Filter func(Context, any) error + Render func(Context, any) error + DatabaseRuntime func(Context, DatabaseStmt) DatabaseStmt } type contextBaseValue struct { context.Context - Logger Logger - Database Database - Client Client - Error error - Values []interface{} + Logger + Database + Client + Error error + Values []any } // ResponseWriter 接口用于写入http请求响应体status、header、body。 @@ -146,30 +151,32 @@ type contextBaseValue struct { // net/http.response实现了flusher、hijacker、pusher接口。 type ResponseWriter interface { // http.ResponseWriter - Header() http.Header Write([]byte) (int, error) + WriteString(string) (int, error) WriteHeader(int) + Header() http.Header // http.Flusher Flush() // http.Hijacker Hijack() (net.Conn, *bufio.ReadWriter, error) // http.Pusher Push(string, *http.PushOptions) error + Size() int Status() int } -// responseWriterHTTP 是对net/http.ResponseWriter接口封装 +// responseWriterHTTP 是对net/http.ResponseWriter接口封装。 type responseWriterHTTP struct { http.ResponseWriter code int size int } -// SetCookie 定义响应返回的set-cookie header的数据生成 -type SetCookie = http.Cookie +// CookieSet 定义响应设置的set-cookie header的数据生成。 +type CookieSet = http.Cookie -// Cookie 定义请求读取的cookie header的键值对数据存储 +// Cookie 定义请求读取的cookie header的键值对数据存储。 type Cookie struct { Name string Value string @@ -178,7 +185,8 @@ type Cookie struct { // contextBaseEntry 实现ContextBase使用的Logger对象。 type contextBaseEntry struct { Logger - Context *contextBase + writeError func(error) + Context *contextBase } // NewContextBasePool 函数从上下文创建一个Context sync.Pool。 @@ -189,10 +197,10 @@ type contextBaseEntry struct { func NewContextBasePool(ctx context.Context) *sync.Pool { config := newContextBaseConfig(ctx) return &sync.Pool{ - New: func() interface{} { + New: func() any { return &contextBase{ - config: config, - httpParams: Params{ParamRoute, ""}, + config: config, + params: Params{ParamRoute, ""}, } }, } @@ -203,17 +211,18 @@ func NewContextBaseFunc(ctx context.Context) func() Context { config := newContextBaseConfig(ctx) return func() Context { return &contextBase{ - config: config, - httpParams: Params{ParamRoute, ""}, + config: config, + params: Params{ParamRoute, ""}, } } } func newContextBaseConfig(ctx context.Context) *contextBaseConfig { - bind, _ := ctx.Value(ContextKeyBind).(func(Context, interface{}) error) - validate, _ := ctx.Value(ContextKeyValidate).(func(Context, interface{}) error) - filte, _ := ctx.Value(ContextKeyFilte).(func(Context, interface{}) error) - render, _ := ctx.Value(ContextKeyRender).(func(Context, interface{}) error) + bind, _ := ctx.Value(ContextKeyBind).(func(Context, any) error) + validater, _ := ctx.Value(ContextKeyValidater).(func(Context, any) error) + filter, _ := ctx.Value(ContextKeyFilter).(func(Context, any) error) + render, _ := ctx.Value(ContextKeyRender).(func(Context, any) error) + db, _ := ctx.Value(ContextKeyDatabaseRuntime).(func(Context, DatabaseStmt) DatabaseStmt) if bind == nil { bind = NewBinds(nil) } @@ -221,23 +230,24 @@ func newContextBaseConfig(ctx context.Context) *contextBaseConfig { render = NewRenders(nil) } return &contextBaseConfig{ - Logger: ctx.Value(ContextKeyApp).(Logger).WithField("logger", true), - Database: ctx.Value(ContextKeyApp).(Database), - Client: ctx.Value(ContextKeyApp).(Client), - Bind: bind, - Validate: validate, - Filte: filte, - Render: render, + Logger: ctx.Value(ContextKeyApp).(Logger), + Database: ctx.Value(ContextKeyApp).(Database), + Client: ctx.Value(ContextKeyApp).(Client), + Bind: bind, + Validater: validater, + Filter: filter, + Render: render, + DatabaseRuntime: db, } } -// Reset Context +// Reset 函数重置Context数据。 func (ctx *contextBase) Reset(w http.ResponseWriter, r *http.Request) { ctx.context = &ctx.contextValues ctx.ResponseWriter = &ctx.httpResponse ctx.RequestReader = r - ctx.httpParams = ctx.httpParams[0:2] - ctx.httpParams[1] = "" + ctx.params = ctx.params[0:2] + ctx.params[1] = "" // cookies body ctx.contextValues.Reset(r.Context(), ctx.config) ctx.httpResponse.Reset(w) @@ -253,7 +263,7 @@ func (ctx *contextBase) GetContext() context.Context { } // Request 获取请求对象。 -// 注意:ctx.Request().Context() 不等于ctx.GetContext() +// 注意:ctx.Request().Context() 不等于ctx.GetContext()。 func (ctx *contextBase) Request() *http.Request { return ctx.RequestReader } @@ -263,7 +273,7 @@ func (ctx *contextBase) Response() ResponseWriter { return ctx.ResponseWriter } -func (ctx *contextBase) Value(key interface{}) interface{} { +func (ctx *contextBase) Value(key any) any { return ctx.contextValues.Value(key) } @@ -281,28 +291,28 @@ func (ctx *contextBase) SetResponse(w ResponseWriter) { ctx.ResponseWriter = w } -// SetLogger 方法设置ContextBases输出日志的基础Logger。 +// SetValue 方法设置内置context.Context的Value,可以调用Value方法读取。 // -// 注意确保设置的是Logger,而不是一个Entry。 -func (ctx *contextBase) SetValue(key, val interface{}) { +// 注意:如果设置Logger时确保设置的是Logger,而不是一个Entry。 +func (ctx *contextBase) SetValue(key, val any) { ctx.contextValues.SetValue(key, val) } // SetHandler 方法设置请求上下文的全部请求处理者。 func (ctx *contextBase) SetHandler(index int, hs HandlerFuncs) { - ctx.index, ctx.handler = index, hs + ctx.index, ctx.handlers = index, hs } // GetHandler 方法获取请求上下文的当前处理索引和全部请求处理者。 func (ctx *contextBase) GetHandler() (int, HandlerFuncs) { - return ctx.index, ctx.handler + return ctx.index, ctx.handlers } // Next 方法调用请求上下文下一个处理函数。 func (ctx *contextBase) Next() { ctx.index++ - for ctx.index < len(ctx.handler) { - ctx.handler[ctx.index](ctx) + for ctx.index < len(ctx.handlers) { + ctx.handlers[ctx.index](ctx) ctx.index++ } } @@ -313,7 +323,7 @@ func (ctx *contextBase) End() { ctx.httpResponse.writeStatus() } -// Err 方法返回 +// Err 方法返回请求上下文取消或处理的错误。 func (ctx *contextBase) Err() error { return ctx.contextValues.Err() } @@ -328,7 +338,7 @@ func (ctx *contextBase) Host() string { return ctx.RequestReader.Host } -// Method 方法返回请求方法, +// Method 方法返回请求方法。 func (ctx *contextBase) Method() string { return ctx.RequestReader.Method } @@ -339,12 +349,15 @@ func (ctx *contextBase) Path() string { } // RealIP 获取用户真实ip,ctx.Request().RemoteAddr()获取远程连接地址。 +// +// 如果server不存在前置代理层直接对外, +// 需要添加中间件过滤请求header X-Real-Ip X-Forwarded-For,防止伪造readip。 func (ctx *contextBase) RealIP() string { if realip := ctx.RequestReader.Header.Get(HeaderXRealIP); realip != "" { return realip } if xforward := ctx.RequestReader.Header.Get(HeaderXForwardedFor); xforward != "" { - return strings.SplitN(string(xforward), ",", 2)[0] + return strings.SplitN(xforward, ",", 2)[0] } addr := strings.SplitN(ctx.RequestReader.RemoteAddr, ":", 2)[0] if addr == "pipe" { @@ -353,12 +366,12 @@ func (ctx *contextBase) RealIP() string { return addr } -// RequestID 获取响应中的X-Request-Id Header +// RequestID 获取响应中的X-Request-Id Header。 func (ctx *contextBase) RequestID() string { return ctx.GetHeader(HeaderXRequestID) } -// ContentType 获取请求内容类型,返回Content-Type Header +// ContentType 获取请求内容类型,返回Content-Type Header。 func (ctx *contextBase) ContentType() string { return ctx.GetHeader(HeaderContentType) } @@ -368,40 +381,40 @@ func (ctx *contextBase) Istls() bool { return ctx.RequestReader.TLS != nil } -var noneSliceByte = make([]byte, 0, 0) +var noneSliceByte = make([]byte, 0) // Body 返回请求的body,并保存到缓存中,可重复调用Body方法, // 每次调用会重置ctx.Request().Body对象成一个body reader。 // // ctx.bodyContent 不会随contextBase一起内存复用,正常应该避免调用Body方法; -// 如果使用应该设置middleware.NewBodyLimitFunc。 +// 如果使用应该设置middleware.NewBodyLimitFunc,避免超大body消耗内存。 func (ctx *contextBase) Body() []byte { if ctx.bodyContent == nil { - bts, err := ioutil.ReadAll(ctx.RequestReader.Body) + body, err := io.ReadAll(ctx.RequestReader.Body) if err != nil { ctx.bodyContent = noneSliceByte - ctx.contextValues.Logger.WithField("depth", 1).WithField(ParamCaller, "Context.Body").Error(err) + ctx.wrapLogger().WithField(ParamCaller, "Context.Body").Error(err) return nil } - ctx.bodyContent = bts + ctx.bodyContent = body } - ctx.RequestReader.Body = ioutil.NopCloser(bytes.NewReader(ctx.bodyContent)) + ctx.RequestReader.Body = io.NopCloser(bytes.NewReader(ctx.bodyContent)) return ctx.bodyContent } // Bind 使用Bind解析请求body并绑定数据。 // 如果Validate不为空,则使用Validate校验数据。 -func (ctx *contextBase) Bind(i interface{}) error { +func (ctx *contextBase) Bind(i any) error { err := ctx.config.Bind(ctx, i) if err != nil { - ctx.contextValues.Logger.WithField("depth", 1).WithField(ParamCaller, "Context.Bind").Error(err) - return NewErrorStatusCode(err, StatucBindFail, CodeBindFail) + ctx.wrapLogger().WithField(ParamCaller, "Context.Bind").Error(err) + return NewErrorWithStatusCode(err, DefaultHandlerDataStatus[0], DefaultHandlerDataCode[0]) } - if ctx.config.Validate != nil { - err = ctx.config.Validate(ctx, i) + if ctx.config.Validater != nil { + err = ctx.config.Validater(ctx, i) if err != nil { - ctx.contextValues.Logger.WithField("depth", 1).WithField(ParamCaller, "Context.Bind(Validate)").Error(err) - return NewErrorStatusCode(err, StatucValidateFail, CodeValidateFail) + ctx.wrapLogger().WithField(ParamCaller, "Context.Bind").Error(err) + return NewErrorWithStatusCode(err, DefaultHandlerDataStatus[1], DefaultHandlerDataCode[1]) } } return nil @@ -409,36 +422,44 @@ func (ctx *contextBase) Bind(i interface{}) error { // Params 获得请求的全部参数。 func (ctx *contextBase) Params() *Params { - return &ctx.httpParams + return &ctx.params } // GetParam 方法获取一个参数的值。 func (ctx *contextBase) GetParam(key string) string { - return ctx.httpParams.Get(key) + return ctx.params.Get(key) } // SetParam 方法设置一个参数。 func (ctx *contextBase) SetParam(key, val string) { - ctx.httpParams = ctx.httpParams.Set(key, val) + ctx.params = ctx.params.Set(key, val) } -// Querys 方法返回http请求的全部uri参数。 +// Querys 方法返回http请求的全部uri参数,数据存储在Request().Form。 func (ctx *contextBase) Querys() url.Values { - err := ctx.RequestReader.ParseForm() - if err != nil { - ctx.contextValues.Logger.WithField("depth", 1).WithField(ParamCaller, "Context.Querys").Error(err) + r := ctx.RequestReader + if r.Form == nil { + var err error + r.Form, err = url.ParseQuery(r.URL.RawQuery) + if err != nil { + ctx.wrapLogger().WithField(ParamCaller, "Context.Querys").Error(err) + } } - return ctx.RequestReader.Form + return r.Form } -// GetQuery 方法获得一个uri参数的值。 +// GetQuery 方法获得一个uri参数的值,数据存储在Request().Form。 func (ctx *contextBase) GetQuery(key string) string { - err := ctx.RequestReader.ParseForm() - if err != nil { - ctx.contextValues.Logger.WithField("depth", 1).WithField(ParamCaller, "Context.GetQuery").Error(err) - return "" + r := ctx.RequestReader + if r.Form == nil { + var err error + r.Form, err = url.ParseQuery(r.URL.RawQuery) + if err != nil { + ctx.wrapLogger().WithField(ParamCaller, "Context.GetQuery").Error(err) + return "" + } } - return ctx.RequestReader.Form.Get(key) + return r.Form.Get(key) } // GetHeader 方法获取一个请求header,相当于ctx.Request().Header().Get(name)。 @@ -446,18 +467,18 @@ func (ctx *contextBase) GetHeader(name string) string { return ctx.RequestReader.Header.Get(name) } -// SetHeader 方法设置一个响应header,相当于ctx.Response().Header().Set(name, val) +// SetHeader 方法设置一个响应header,相当于ctx.Response().Header().Set(name, val)。 func (ctx *contextBase) SetHeader(name string, val string) { ctx.ResponseWriter.Header().Set(name, val) } -// Cookies 方法获取全部请求的cookie,获取的cookie值是首次调用Cookies/GetCookie方法后解析的数据。。 +// Cookies 方法获取全部请求的cookie,获取的cookie值是首次调用Cookies/GetCookie方法后解析的数据。 func (ctx *contextBase) Cookies() []Cookie { ctx.readCookies() return ctx.cookies } -// GetCookie 获方法得一个请求cookie的值,获取的cookie值是首次调用Cookies/GetCookie方法后解析的数据。。 +// GetCookie 获方法得一个请求cookie的值,获取的cookie值是首次调用Cookies/GetCookie方法后解析的数据。 func (ctx *contextBase) GetCookie(name string) string { ctx.readCookies() for _, cookie := range ctx.cookies { @@ -468,8 +489,8 @@ func (ctx *contextBase) GetCookie(name string) string { return "" } -// SetCookie 方法设置一个响应cookie的数据,设置响应 Set-Cookie header。 -func (ctx *contextBase) SetCookie(cookie *SetCookie) { +// SetCookie 方法设置一个响应cookie的数据,设置响应header Set-Cookie,运行设置各种自定义cookie。 +func (ctx *contextBase) SetCookie(cookie *CookieSet) { if v := cookie.String(); v != "" { ctx.ResponseWriter.Header().Add(HeaderSetCookie, v) } @@ -477,60 +498,136 @@ func (ctx *contextBase) SetCookie(cookie *SetCookie) { // SetCookieValue 方法设置一个响应cookie,如果maxAge非0则设置Max-Age属性。 func (ctx *contextBase) SetCookieValue(name, value string, maxAge int) { - if maxAge != 0 { - ctx.ResponseWriter.Header().Add(HeaderSetCookie, fmt.Sprintf("%s=%s; Max-Age=%d", name, url.QueryEscape(value), maxAge)) - } else { - ctx.ResponseWriter.Header().Add(HeaderSetCookie, fmt.Sprintf("%s=%s;", name, url.QueryEscape(value))) - } + ctx.SetCookie(&CookieSet{ + Name: name, + Value: value, + MaxAge: maxAge, + }) } -// FormValue 使用body解析成Form数据,并返回对应key的值 +// FormValue 使用body解析成Form数据,并返回对应key的值。 func (ctx *contextBase) FormValue(key string) string { - if ctx.parseForm() != nil { - return "" + r := ctx.RequestReader + if r.PostForm == nil { + err := parseForm(r) + if err != nil { + r.PostForm = make(url.Values) + ctx.wrapLogger().WithField(ParamCaller, "Context.FormValue").Error(err) + return "" + } } - val, ok := ctx.RequestReader.MultipartForm.Value[key] + + val, ok := r.PostForm[key] if ok && len(val) != 0 { return val[0] } return "" } -// FormValues 使用body解析成Form数据,并返回全部的值 +// FormValues 使用body解析成Form数据,并返回全部的值。 func (ctx *contextBase) FormValues() map[string][]string { - if ctx.parseForm() != nil { - return nil + r := ctx.RequestReader + if r.PostForm == nil { + err := parseForm(r) + if err != nil { + r.PostForm = make(url.Values) + ctx.wrapLogger().WithField(ParamCaller, "Context.FormValues").Error(err) + return nil + } } - return ctx.RequestReader.MultipartForm.Value + return r.PostForm } -// FormFile 使用body解析成Form数据,并返回对应key的文件 +// FormFile 使用body解析成Form数据,并返回对应key的文件。 func (ctx *contextBase) FormFile(key string) *multipart.FileHeader { - if ctx.parseForm() != nil { - return nil + r := ctx.RequestReader + if r.PostForm == nil { + err := parseForm(r) + if err != nil { + r.PostForm = make(url.Values) + ctx.wrapLogger().WithField(ParamCaller, "Context.FormFile").Error(err) + return nil + } } - val, ok := ctx.RequestReader.MultipartForm.File[key] - if ok && len(val) != 0 { - return val[0] + + if r.MultipartForm != nil { + val, ok := r.MultipartForm.File[key] + if ok && len(val) != 0 { + return val[0] + } } return nil } // FormFiles 使用body解析成Form数据,并返回全部的文件。 func (ctx *contextBase) FormFiles() map[string][]*multipart.FileHeader { - if ctx.parseForm() != nil { - return nil + r := ctx.RequestReader + if r.PostForm == nil { + err := parseForm(r) + if err != nil { + r.PostForm = make(url.Values) + ctx.wrapLogger().WithField(ParamCaller, "Context.FormFiles").Error(err) + return nil + } } - return ctx.RequestReader.MultipartForm.File + + if r.MultipartForm != nil { + return r.MultipartForm.File + } + return nil } -// parseForm 解析form数据。 -func (ctx *contextBase) parseForm() error { - err := ctx.RequestReader.ParseMultipartForm(DefaultContextFormMaxMemory) - if err != nil && err.Error() != "http: multipart handled by MultipartReader" { - ctx.contextValues.Logger.WithField("depth", 2).WithField(ParamCaller, "Context.Form...").Error(err) +// parseForm 函数解析form数据,不会将PostForm数据复制到Form。 +// +// 如果Body为http.NoBody时PostForm = Form。 +func parseForm(r *http.Request) error { + if r.Body == http.NoBody { + if r.Form == nil { + var err error + r.Form, err = url.ParseQuery(r.URL.RawQuery) + if err != nil { + return err + } + } + r.PostForm = r.Form + return nil + } + + t, params, err := mime.ParseMediaType(r.Header.Get(HeaderContentType)) + if err != nil { return err } + switch t { + case MimeApplicationForm: + var reader io.Reader = r.Body + if reflect.TypeOf(reader).String() != "*http.maxBytesReader" { + reader = io.LimitReader(r.Body, DefaultContextMaxApplicationFormSize) + } + body, err := io.ReadAll(reader) + if err != nil { + return err + } + + val, err := url.ParseQuery(string(body)) + if err != nil { + return err + } + r.PostForm = val + case MimeMultipartForm, MimeMultipartMixed: + boundary, ok := params["boundary"] + if !ok { + return http.ErrMissingBoundary + } + + form, err := multipart.NewReader(r.Body, boundary).ReadForm(DefaultContextMaxMultipartFormMemory) + if err != nil { + return err + } + r.PostForm = form.Value + r.MultipartForm = form + default: + return fmt.Errorf(ErrFormatContextParseFormNotSupportContentType, t) + } return nil } @@ -541,73 +638,67 @@ func (ctx *contextBase) WriteHeader(code int) { // Redirect implement request redirection. // -// Redirect 实现请求重定向。 +// Redirect 实现请求重定向,状态码需要为30x或201。 func (ctx *contextBase) Redirect(code int, url string) { + if (code < http.StatusMultipleChoices || code > http.StatusPermanentRedirect) && code != StatusCreated { + ctx.wrapLogger().WithField(ParamCaller, "Context.Redirect").Error(fmt.Errorf(ErrFormatContextRedirectInvalid, code)) + return + } http.Redirect(ctx.ResponseWriter, ctx.RequestReader, url, code) } -// Push 实现http2 push +// Push 方法实现http2 push。 +// +// support of HTTP/2 Server Push will be disabled by default in +// Chrome 106 and other Chromium-based browsers. func (ctx *contextBase) Push(target string, opts *http.PushOptions) error { - if opts == nil { - opts = &http.PushOptions{ - Header: http.Header{ - HeaderAcceptEncoding: []string{ctx.RequestReader.Header.Get(HeaderAcceptEncoding)}, - }, - } - } - err := ctx.ResponseWriter.Push(target, opts) - if err != nil { - ctx.contextValues.Logger.WithField("depth", 1).WithField(ParamCaller, "Context.Push"). - Errorf("Failed to push: %v, Resource path: %s.", err, target) + if err != nil && (errors.Is(err, http.ErrNotSupported) || DefaultContextPushNotSupportedError) { + err = fmt.Errorf(ErrFormatContextPushFailed, target, err) + ctx.wrapLogger().WithField(ParamCaller, "Context.Push").Error(err) } return err } // Render 使用app.Renderer返回数据。 -func (ctx *contextBase) Render(i interface{}) error { +func (ctx *contextBase) Render(i any) error { var err error - if ctx.config.Filte != nil { - err = ctx.config.Filte(ctx, i) + if ctx.config.Filter != nil { + err = ctx.config.Filter(ctx, i) if err != nil { - ctx.contextValues.Logger.WithField("depth", 1).WithField(ParamCaller, "Context.Render(Filte)").Error(err) - return NewErrorStatusCode(err, StatucFilteFail, CodeFilteFail) + ctx.wrapLogger().WithField(ParamCaller, "Context.Render").Error(err) + return NewErrorWithStatusCode(err, DefaultHandlerDataStatus[2], DefaultHandlerDataCode[2]) } } err = ctx.config.Render(ctx, i) if err != nil { - ctx.contextValues.Logger.WithField("depth", 1).WithField(ParamCaller, "Context.Render").Error(err) + ctx.wrapLogger().WithField(ParamCaller, "Context.Render").Error(err) } - return NewErrorStatusCode(err, StatucRenderFail, CodeRenderFail) + return NewErrorWithStatusCode(err, DefaultHandlerDataStatus[3], DefaultHandlerDataCode[3]) } // Write 实现io.Writer,向响应写入数据。 func (ctx *contextBase) Write(data []byte) (n int, err error) { n, err = ctx.ResponseWriter.Write(data) if err != nil { - ctx.contextValues.Logger.WithField("depth", 1).WithField(ParamCaller, "Context.Write").Error(err) + ctx.wrapLogger().WithField(ParamCaller, "Context.Write").Error(err) } return } // WriteString 实现向响应写入一个字符串。 -func (ctx *contextBase) WriteString(i string) (err error) { - header := ctx.ResponseWriter.Header() - if val := header.Get(HeaderContentType); len(val) == 0 { - header.Add(HeaderContentType, MimeTextPlainCharsetUtf8) - } - _, err = ctx.ResponseWriter.Write([]byte(i)) +func (ctx *contextBase) WriteString(data string) (n int, err error) { + n, err = ctx.ResponseWriter.WriteString(data) if err != nil { - ctx.contextValues.Logger.WithField("depth", 1).WithField(ParamCaller, "Context.WriteString").Error(err) + ctx.wrapLogger().WithField(ParamCaller, "Context.WriteString").Error(err) } return } // WriteFile 使用HandlerFile处理一个静态文件。 -func (ctx *contextBase) WriteFile(path string) error { +func (ctx *contextBase) WriteFile(path string) { http.ServeFile(ctx.ResponseWriter, ctx.RequestReader, path) - return nil } // writeError 方法返回error数据,该方法不应该被直接使用,调用ctx.Fatal方法会自动调用writeError方法。 @@ -617,32 +708,33 @@ func (ctx *contextBase) writeError(err error) { w := ctx.ResponseWriter if w.Size() == 0 { status := w.Status() - if status == 200 { + if status == StatusOK { ctx.WriteHeader(getErrorStatus(err)) } - ctx.Render(NewContextMessgae(ctx, err, nil)) + _ = ctx.Render(NewContextMessgae(ctx, err, nil)) } ctx.contextValues.Error = err ctx.End() } +type contextMessage struct { + Time string `json:"time" protobuf:"1,name=time" xml:"time" yaml:"time"` + Host string `json:"host" protobuf:"2,name=host" xml:"host" yaml:"host"` + Method string `json:"method" protobuf:"3,name=method" xml:"method" yaml:"method"` + Path string `json:"path" protobuf:"4,name=path" xml:"path" yaml:"path"` + Route string `json:"route" protobuf:"5,name=route" xml:"route" yaml:"route"` + Status int `json:"status" protobuf:"6,name=status" xml:"status" yaml:"status"` + Code int `json:"code,omitempty" protobuf:"7,name=code" xml:"code,omitempty" yaml:"code,omitempty"` + XRequestID string `json:"x-request-id,omitempty" protobuf:"8,name=x-request-id" xml:"x-request-id,omitempty" yaml:"x-request-id,omitempty"` + XTraceID string `json:"x-trace-id,omitempty" protobuf:"9,name=x-trace-id" xml:"x-trace-id,omitempty" yaml:"x-trace-id,omitempty"` + Error string `json:"error,omitempty" protobuf:"10,name=error" xml:"error,omitempty" yaml:"error,omitempty"` + Message any `json:"message,omitempty" protobuf:"11,name=message" xml:"message,omitempty" yaml:"message,omitempty"` +} + // NewContextMessgae 方法从请求上下文创建一个error或对象响应对象,记录请求上下文相关信息。 -func NewContextMessgae(ctx Context, err error, message interface{}) interface{} { - type contextMessage struct { - Time string `json:"time" protobuf:"1,name=time" xml:"time" yaml:"time"` - Host string `json:"host" protobuf:"2,name=host" xml:"host" yaml:"host"` - Method string `json:"method" protobuf:"3,name=method" xml:"method" yaml:"method"` - Path string `json:"path" protobuf:"4,name=path" xml:"path" yaml:"path"` - Route string `json:"route" protobuf:"5,name=route" xml:"route" yaml:"route"` - Status int `json:"status" protobuf:"6,name=status" xml:"status" yaml:"status"` - Code int `json:"code,omitempty" protobuf:"7,name=code" xml:"code,omitempty" yaml:"code,omitempty"` - XRequestID string `json:"x-request-id,omitempty" protobuf:"8,name=x-request-id" xml:"x-request-id,omitempty" yaml:"x-request-id,omitempty"` - XTraceID string `json:"x-trace-id,omitempty" protobuf:"9,name=x-trace-id" xml:"x-trace-id,omitempty" yaml:"x-trace-id,omitempty"` - Error string `json:"error,omitempty" protobuf:"10,name=error" xml:"error,omitempty" yaml:"error,omitempty"` - Message interface{} `json:"message,omitempty" protobuf:"11,name=message" xml:"message,omitempty" yaml:"message,omitempty"` - } +func NewContextMessgae(ctx Context, err error, message any) any { msg := contextMessage{ - Time: time.Now().Format(DefaultLoggerTimeFormat), + Time: time.Now().Format(DefaultLoggerFormatterFormatTime), Host: ctx.Host(), Method: ctx.Method(), Path: ctx.Path(), @@ -660,91 +752,79 @@ func NewContextMessgae(ctx Context, err error, message interface{}) interface{} } func getErrorStatus(err error) int { - for { - stater, ok := err.(interface{ Status() int }) - if ok { + for err != nil { + if stater, ok := err.(interface{ Status() int }); ok { //nolint:errorlint return stater.Status() } - u, ok := err.(interface { - Unwrap() error - }) - if ok { - err = u.Unwrap() - } else { - return 500 - } + err = errors.Unwrap(err) } + return StatusInternalServerError } func getErrorCode(err error) int { - for { - coder, ok := err.(interface{ Code() int }) - if ok { + for err != nil { + if coder, ok := err.(interface{ Code() int }); ok { //nolint:errorlint return coder.Code() } - u, ok := err.(interface { - Unwrap() error - }) - if ok { - err = u.Unwrap() - } else { - return 0 - } + err = errors.Unwrap(err) } + return 0 } // Query 方法调用Database.Query查询数据块。 -func (ctx *contextBase) Query(data interface{}, stmt DatabaseStmt) error { - err := ctx.contextValues.Database.Query(ctx.context, data, stmt) - if err != nil { - ctx.contextValues.Logger.WithField("depth", 1).WithField(ParamCaller, "Context.Query").Error(err) +func (ctx *contextBase) Query(data any, stmt DatabaseStmt) error { + if ctx.config.DatabaseRuntime != nil { + stmt = ctx.config.DatabaseRuntime(ctx, stmt) } - return err + return ctx.contextValues.Database.Query(ctx.context, data, stmt) } // Exec 方法调用Database.Exec执行数据块。 func (ctx *contextBase) Exec(stmt DatabaseStmt) error { - err := ctx.contextValues.Database.Exec(ctx.context, stmt) - if err != nil { - ctx.contextValues.Logger.WithField("depth", 1).WithField(ParamCaller, "Context.Exec").Error(err) + if ctx.config.DatabaseRuntime != nil { + stmt = ctx.config.DatabaseRuntime(ctx, stmt) } - return err + return ctx.contextValues.Database.Exec(ctx.context, stmt) } -func (ctx *contextBase) NewRequest(method, path string, options ...interface{}) error { +func (ctx *contextBase) NewRequest(method, path string, options ...any) error { return ctx.contextValues.Client.NewRequest(ctx.context, method, path, options...) } +func (ctx *contextBase) wrapLogger() Logger { + return ctx.contextValues.WithField(ParamDepth, 1) +} + // Debug 方法写入Debug日志。 -func (ctx *contextBase) Debug(args ...interface{}) { - ctx.contextValues.Logger.WithField("depth", 1).Debug(args...) +func (ctx *contextBase) Debug(args ...any) { + ctx.wrapLogger().Debug(args...) } // Info 方法写入Info日志。 -func (ctx *contextBase) Info(args ...interface{}) { - ctx.contextValues.Logger.WithField("depth", 1).Info(args...) +func (ctx *contextBase) Info(args ...any) { + ctx.wrapLogger().Info(args...) } // Warning 方法写入Warning日志。 -func (ctx *contextBase) Warning(args ...interface{}) { - ctx.contextValues.Logger.WithField("depth", 1).Warning(args...) +func (ctx *contextBase) Warning(args ...any) { + ctx.wrapLogger().Warning(args...) } // Error 方法写入Error日志。 -func (ctx *contextBase) Error(args ...interface{}) { - ctx.contextValues.Logger.WithField("depth", 1).Error(args...) +func (ctx *contextBase) Error(args ...any) { + ctx.wrapLogger().Error(args...) } // Fatal 方法写入Error日志,并结束请求上下文处理。 // // 注意:如果err中存在敏感信息会被写入到响应中。 -func (ctx *contextBase) Fatal(args ...interface{}) { +func (ctx *contextBase) Fatal(args ...any) { err := getMessagError(args) ctx.writeError(err) - ctx.contextValues.Logger.WithField("depth", 1).Error(err.Error()) + ctx.wrapLogger().Error(err.Error()) } -func getMessagError(args []interface{}) error { +func getMessagError(args []any) error { if len(args) == 1 { err, ok := args[0].(error) if ok { @@ -757,74 +837,74 @@ func getMessagError(args []interface{}) error { } // Debugf 方法输出Info日志。 -func (ctx *contextBase) Debugf(format string, args ...interface{}) { - ctx.contextValues.Logger.WithField("depth", 1).Debug(fmt.Sprintf(format, args...)) +func (ctx *contextBase) Debugf(format string, args ...any) { + ctx.wrapLogger().Debug(fmt.Sprintf(format, args...)) } // Infof 方法输出Info日志。 -func (ctx *contextBase) Infof(format string, args ...interface{}) { - ctx.contextValues.Logger.WithField("depth", 1).Info(fmt.Sprintf(format, args...)) +func (ctx *contextBase) Infof(format string, args ...any) { + ctx.wrapLogger().Info(fmt.Sprintf(format, args...)) } // Warningf 方法输出Warning日志。 -func (ctx *contextBase) Warningf(format string, args ...interface{}) { - ctx.contextValues.Logger.WithField("depth", 1).Warning(fmt.Sprintf(format, args...)) +func (ctx *contextBase) Warningf(format string, args ...any) { + ctx.wrapLogger().Warning(fmt.Sprintf(format, args...)) } // Errorf 方法输出Error日志。 -func (ctx *contextBase) Errorf(format string, args ...interface{}) { - ctx.contextValues.Logger.WithField("depth", 1).Error(fmt.Sprintf(format, args...)) +func (ctx *contextBase) Errorf(format string, args ...any) { + ctx.wrapLogger().Error(fmt.Sprintf(format, args...)) } // Fatalf 方法输出Fatal日志,并结束请求上下文处理。 // // 注意:如果err中存在敏感信息会被写入到响应中。 -func (ctx *contextBase) Fatalf(format string, args ...interface{}) { +func (ctx *contextBase) Fatalf(format string, args ...any) { msg := fmt.Sprintf(format, args...) ctx.writeError(errors.New(msg)) - ctx.contextValues.Logger.WithField("depth", 1).Errorf(msg) + ctx.wrapLogger().Errorf(msg) } // WithField 方法增加一个日志属性,返回一个新的Logger。 -func (ctx *contextBase) WithField(key string, value interface{}) Logger { +func (ctx *contextBase) WithField(key string, value any) Logger { return &contextBaseEntry{ - Logger: ctx.contextValues.Logger.WithField(key, value), - Context: ctx, + Logger: ctx.contextValues.WithField(key, value), + writeError: ctx.writeError, } } // WithFields 方法增加多个日志属性,返回一个新的Logger。 // // 如果fields包含file条目属性,则不会添加调用位置信息。 -func (ctx *contextBase) WithFields(keys []string, fields []interface{}) Logger { +func (ctx *contextBase) WithFields(keys []string, fields []any) Logger { return &contextBaseEntry{ - Logger: ctx.contextValues.Logger.WithFields(keys, fields), - Context: ctx, + Logger: ctx.contextValues.WithFields(keys, fields), + writeError: ctx.writeError, } } // Fatal 方法重写Context的Fatal方法,不执行panic,http返回500和请求id。 -func (e *contextBaseEntry) Fatal(args ...interface{}) { +func (e *contextBaseEntry) Fatal(args ...any) { err := getMessagError(args) - e.Context.writeError(err) - e.Logger.WithField("depth", 1).Error(err.Error()) + e.writeError(err) + e.Error(err.Error()) } // Fatalf 方法重写Context的Fatalf方法,不执行panic,http返回500和请求id。 -func (e *contextBaseEntry) Fatalf(format string, args ...interface{}) { +func (e *contextBaseEntry) Fatalf(format string, args ...any) { msg := fmt.Sprintf(format, args...) - e.Context.writeError(errors.New(msg)) - e.Logger.WithField("depth", 1).Error(msg) + e.writeError(errors.New(msg)) + e.Error(msg) } // WithField 方法增加一个日志属性。 -func (e *contextBaseEntry) WithField(key string, value interface{}) Logger { +func (e *contextBaseEntry) WithField(key string, value any) Logger { e.Logger = e.Logger.WithField(key, value) return e } // WithFields 方法增加多个日志属性。 -func (e *contextBaseEntry) WithFields(keys []string, fields []interface{}) Logger { +func (e *contextBaseEntry) WithFields(keys []string, fields []any) Logger { e.Logger = e.Logger.WithFields(keys, fields) return e } @@ -835,20 +915,19 @@ func (ctx *contextBase) readCookies() { return } for _, line := range ctx.RequestReader.Header[HeaderCookie] { - parts := strings.Split(line, "; ") - // Per-line attributes - for i := 0; i < len(parts); i++ { - if len(parts[i]) == 0 { + line = textproto.TrimString(line) + var part string + for len(line) > 0 { // continue since we have rest + part, line, _ = strings.Cut(line, ";") + part = textproto.TrimString(part) + if part == "" { continue } - name, val := parts[i], "" - if j := strings.Index(name, "="); j >= 0 { - name, val = name[:j], name[j+1:] - } + name, val, _ := strings.Cut(part, "=") if !isCookieNameValid(name) { continue } - val, ok := parseCookieValue(val, true) + val, ok := parseCookieValue(val) if !ok { continue } @@ -857,9 +936,37 @@ func (ctx *contextBase) readCookies() { } } -func parseCookieValue(raw string, allowDoubleQuote bool) (string, bool) { - // Strip the quotes, if present. - if allowDoubleQuote && len(raw) > 1 && raw[0] == '"' && raw[len(raw)-1] == '"' { +var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-") + +// String 方法返回Cookie格式化字符串。 +func (c Cookie) String() string { + v := sanitizeCookieValue(c.Value) + if strings.ContainsAny(v, " ,") { + return `"` + v + `"` + } + return cookieNameSanitizer.Replace(c.Name) + ":" + v +} + +func sanitizeCookieValue(v string) string { + for i := 0; i < len(v); i++ { + if validCookieValueByte(v[i]) { + continue + } + + buf := make([]byte, 0, len(v)) + buf = append(buf, v[:i]...) + for ; i < len(v); i++ { + if b := v[i]; validCookieValueByte(b) { + buf = append(buf, b) + } + } + return string(buf) + } + return v +} + +func parseCookieValue(raw string) (string, bool) { + if len(raw) > 1 && raw[0] == '"' && raw[len(raw)-1] == '"' { raw = raw[1 : len(raw)-1] } for i := 0; i < len(raw); i++ { @@ -883,10 +990,10 @@ func isCookieNameValid(raw string) bool { func isNotToken(r rune) bool { i := int(r) - return !(i < len(isTokenTable) && isTokenTable[i]) + return !(i < len(tableCookie) && tableCookie[i]) } -var isTokenTable = [127]bool{ +var tableCookie = [127]bool{ '!': true, '#': true, '$': true, '%': true, '&': true, '\'': true, '*': true, '+': true, '-': true, '.': true, '0': true, '1': true, '2': true, '3': true, '4': true, '5': true, '6': true, '7': true, '8': true, '9': true, 'A': true, 'B': true, 'C': true, 'D': true, @@ -907,7 +1014,8 @@ func (ctx *contextBaseValue) Reset(c context.Context, config *contextBaseConfig) ctx.Error = nil ctx.Values = ctx.Values[0:0] } -func (ctx *contextBaseValue) SetValue(key, val interface{}) { + +func (ctx *contextBaseValue) SetValue(key, val any) { switch key { case ContextKeyLogger: ctx.Logger = val.(Logger) @@ -926,7 +1034,7 @@ func (ctx *contextBaseValue) SetValue(key, val interface{}) { } } -func (ctx *contextBaseValue) Value(key interface{}) interface{} { +func (ctx *contextBaseValue) Value(key any) any { switch key { case ContextKeyLogger: return ctx.Logger @@ -968,29 +1076,42 @@ func (w *responseWriterHTTP) Reset(writer http.ResponseWriter) { w.size = 0 } +// Unwrap 方法返回原始http.ResponseWrite对象。 +func (w *responseWriterHTTP) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} + // Write 方法实现io.Writer接口。 func (w *responseWriterHTTP) Write(data []byte) (int, error) { w.writeStatus() n, err := w.ResponseWriter.Write(data) - w.size = w.size + n + w.size += n + return n, err +} + +func (w *responseWriterHTTP) WriteString(data string) (int, error) { + w.writeStatus() + n, err := io.WriteString(w.ResponseWriter, data) + w.size += n return n, err } // WriteHeader 方法实现写入http请求状态码。 -func (w *responseWriterHTTP) WriteHeader(codeCode int) { - w.code = codeCode +func (w *responseWriterHTTP) WriteHeader(code int) { + w.code = code } func (w *responseWriterHTTP) writeStatus() { if w.code > 0 && w.code != 200 { w.ResponseWriter.WriteHeader(w.code) - w.code *= -1 + w.code = -w.code } } // Flush 方法实现刷新缓冲,将缓冲的请求发送给客户端。 func (w *responseWriterHTTP) Flush() { if flusher, ok := w.ResponseWriter.(http.Flusher); ok { + w.writeStatus() flusher.Flush() } } @@ -998,9 +1119,10 @@ func (w *responseWriterHTTP) Flush() { // Hijack 方法实现劫持http连接,用于websocket连接。 func (w *responseWriterHTTP) Hijack() (net.Conn, *bufio.ReadWriter, error) { if hijacker, ok := w.ResponseWriter.(http.Hijacker); ok { + w.code = -StatusSwitchingProtocols return hijacker.Hijack() } - return nil, nil, ErrResponseWriterHTTPNotHijacker + return nil, nil, ErrResponseWriterNotHijacker } // Push 方法实现http Psuh,如果responseWriterHTTP实现http.Push接口,则Push资源。 @@ -1019,7 +1141,7 @@ func (w *responseWriterHTTP) Size() int { // Status 方法获得设置的http状态码。 func (w *responseWriterHTTP) Status() int { if w.code < 0 { - return w.code * -1 + return -w.code } return w.code } diff --git a/controller.go b/controller.go index 426479c..4a510c3 100644 --- a/controller.go +++ b/controller.go @@ -11,31 +11,29 @@ import ( Controller defines the controller interface. The default AutoRoute controller implements the following functions: + The controller method maps the route method path. - Controller construction error delivery (NewControllerError) - Custom controller function mapping relationship (implement func ControllerRoute() map[string]string) - Custom controller routing group and routing parameters (implement func ControllerParam(pkg, name, method string) string) - Controller routing combination, if a controller named xxxController is combined, the routing method of the xxx controller will be combined - Controller method combination, if you combine a controller with a name other than xxxController, you can directly call the method in the controller property ctl.xxx. + Controller construction error delivery. + Custom controller function mapping relationship. + Custom controller routing group and routing parameters. + Controller routing combination, get the routing method of the combined controller. + Controller method composition, using the calling method controlled by the composition. Controller 定义控制器接口。 默认AutoRoute控制器实现下列功能: + 控制器方法映射路由方法路径。 - 控制器构造错误传递(NewControllerError) - 自定义控制器函数映射关系(实现func ControllerRoute() map[string]string) - 自定义控制器路由组和路由参数(实现func ControllerParam(pkg, name, method string) string) - 控制器路由组合,如果组合一个名称为xxxController控制器,会组合获得xxx控制器的路由方法 - 控制器方法组合,如果组合一个名称非xxxController控制器,可以控制器属性ctl.xxx直接调用方法。 + 控制器构造错误传递(NewControllerError)。 + 自定义控制器函数映射关系(实现'func ControllerRoute() map[string]string)')。 + 自定义控制器路由组和路由参数(实现'func ControllerParam(pkg, name, method string) string')。 + 控制器路由组合,获得被组合控制器的路由方法。 + 控制器方法组合,使用被组合控制的的调用方法。 */ type Controller interface { Inject(Controller, Router) error } -type controllerName interface { - ControllerName() string -} - type controllerGroup interface { ControllerGroup(string) string } @@ -50,14 +48,15 @@ type controllerParam interface { ControllerParam(string, string, string) string } -// The ControllerAutoRoute implements the routing mapping controller to register the corresponding router method according to the method. +// The ControllerAutoRoute implements the routing mapping controller +// to register the corresponding router method according to the method. // // ControllerAutoRoute 实现路由映射控制器根据方法注册对应的路由器方法。 type ControllerAutoRoute struct{} type controllerError struct { + Controller Error error - Name string } // NewControllerError function returns a controller error, and the corresponding error is returned when the controller Inject. @@ -65,72 +64,83 @@ type controllerError struct { // NewControllerError 函数返回一个控制器错误,在控制器Inject时返回对应的错误。 func NewControllerError(ctl Controller, err error) Controller { return &controllerError{ - Error: err, - Name: getControllerPathName(ctl), + Controller: ctl, + Error: err, } } // The Inject method returns a controller error when injecting routing rules. // // Inject 方法在注入路由规则时返回控制器错误。 -func (ctl *controllerError) Inject(Controller, Router) error { +func (ctl controllerError) Inject(Controller, Router) error { return ctl.Error } -// The ControllerName method returns the controller name of controllerError. -// -// ControllerName 方法返回controllerError的控制器名称。 -func (ctl *controllerError) ControllerName() string { - return ctl.Name +func (ctl controllerError) Unwrap() Controller { + return ctl.Controller } // Inject method implements the method of injecting the controller into the router, // and the ControllerAutoRoute controller calls the ControllerInjectAutoRoute method to inject. // // Inject 方法实现控制器注入到路由器的方法,ControllerAutoRoute控制器调用ControllerInjectAutoRoute方法注入。 -func (ctl *ControllerAutoRoute) Inject(controller Controller, router Router) error { +func (ctl ControllerAutoRoute) Inject(controller Controller, router Router) error { return ControllerInjectAutoRoute(controller, router) } -// ControllerInjectAutoRoute function generates routing rules based on the controller rules, and the usage method is converted into a processing function to support routers. +// ControllerInjectAutoRoute function generates routing rules based on the controller rules, +// and the usage method is converted into a processing function to support routers. // -// Routing group: If the'ControllerGroup(string) string' method is implemented, the routing group is returned; if the routing parameter ParamControllerGroup is included, it is used; otherwise, the controller name is used to turn the path. +// Routing group: If the'ControllerGroup(string) string' method is implemented, +// the routing group is returned; if the routing parameter ParamControllerGroup is included, +// it is used; otherwise, the controller name is used to turn the path. // -// Routing path: Convert the method with the request method as the prefix to the routing method and path, and then use the map[method]path returned by the'ControllerRoute() map[string]string' method to overwrite the routing path. +// Routing path: Convert the method with the request method as the prefix to the routing method and path, +// and then use the map[method]path returned by the'ControllerRoute() map[string]string' method to overwrite the routing path. // -// Method conversion rules: The method prefix must be a valid request method (within DefaultRouterAllMethod), the remaining path is converted to a path, ByName is converted to variable matching/:name, and the last By of the method path is converted to /*; -// The return path of ControllerRoute is'-' and the method is ignored. The first character is'', which means it is a path append parameter. +// Method conversion rules: The method prefix must be a valid request method (within DefaultRouterAllMethod), +// the remaining path is converted to a path, ByName is converted to variable matching/:name, +// and the last By of the method path is converted to /*; +// The return path of ControllerRoute is'-' and the method is ignored. The first character is”, +// which means it is a path append parameter. // -// Routing parameters: If you implement the'ControllerParam(string, string, string) string' method to return routing parameters, otherwise use "controllername=%s.%s controllermethod=%s". +// Routing parameters: If you implement the'ControllerParam(string, string, string) string' method to return routing parameters, +// otherwise use "controllername=%s.%s controllermethod=%s". // -// Controller combination: If the controller combines other objects, only the methods of the object whose name suffix is ​​Controller are reserved, and other methods with embedded properties will be ignored. +// Controller combination: If the controller combines other objects, +// only the methods of the object whose name suffix is Controller are reserved, +// and other methods with embedded properties will be ignored. // // ControllerInjectAutoRoute 函数基于控制器规则生成路由规则,使用方法转换成处理函数支持路由器。 // -// 路由组: 如果实现'ControllerGroup(string) string'方法返回路由组;如果包含路由参数ParamControllerGroup则使用;否则使用控制器名称驼峰转路径。 +// 路由组: 如果实现'ControllerGroup(string) string'方法返回路由组; +// 如果包含路由参数ParamControllerGroup则使用;否则使用控制器名称驼峰转路径。 // -// 路由路径: 将请求方法为前缀的方法转换成路由方法和路径,然后使用'ControllerRoute() map[string]string'方法返回的map[method]path覆盖路由路径。 +// 路由路径: 将请求方法为前缀的方法转换成路由方法和路径, +// 然后使用'ControllerRoute() map[string]string'方法返回的map[method]path覆盖路由路径。 // -// 方法转换规则: 方法前缀必须是有效的请求方法(DefaultRouterAllMethod之内),剩余路径驼峰转路径,ByName转换成变量匹配/:name,方法路径最后一个By转换成/*; +// 方法转换规则: 方法前缀必须是有效的请求方法(DefaultRouterAllMethod之内), +// 剩余路径驼峰转路径,ByName转换成变量匹配/:name,方法路径最后一个By转换成/*; // ControllerRoute返回路径为'-'则忽略方法,第一个字符为' '表示为路径追加参数。 // -// 路由参数: 如果实现'ControllerParam(string, string, string) string'方法返回路由参数,否则使用"controllername=%s.%s controllermethod=%s"。 +// 路由参数: 如果实现'ControllerParam(string, string, string) string'方法返回路由参数, +// 否则使用"controllername=%s.%s controllermethod=%s"。 // // 控制器组合: 如果控制器组合了其他对象,仅保留名称后缀为Controller的对象的方法,其他嵌入属性的方法将被忽略。 func ControllerInjectAutoRoute(controller Controller, router Router) error { iType := reflect.TypeOf(controller) - iValue := reflect.ValueOf(controller) + v := reflect.ValueOf(controller) // 添加控制器组。 - cname := getControllerName(reflect.Indirect(iValue)) - cpkg := reflect.Indirect(iValue).Type().PkgPath() + cname := getControllerName(v) + cpkg := reflect.Indirect(v).Type().PkgPath() router = router.Group(getContrllerRouterGroup(controller, cname, router)) // 获取路由参数函数 pfn := defaultRouteParam - v, ok := controller.(controllerParam) + p, ok := controller.(controllerParam) if ok { - pfn = v.ControllerParam + pfn = p.ControllerParam } // 路由器注册控制器方法 @@ -141,13 +151,16 @@ func ControllerInjectAutoRoute(controller Controller, router Router) error { continue } - h := iValue.Method(m.Index).Interface() + h := v.Method(m.Index).Interface() SetHandlerAliasName(h, fmt.Sprintf("%s.%s.%s", cpkg, cname, name)) method := getMethodByName(name) if method == "" { - method = "ANY" + method = MethodAny + } + err := router.AddHandler(method, paths[i]+" "+pfn(cpkg, cname, name), h) + if err != nil { + return err } - router.AddHandler(method, paths[i]+" "+pfn(cpkg, cname, name), h) } return nil } @@ -179,7 +192,7 @@ func getContrllerRouterGroup(controller Controller, name string, router Router) // defaultRouteParam 函数定义默认的控制器参数,可以通过实现controllerParam来覆盖该函数。 func defaultRouteParam(pkg, name, method string) string { - return fmt.Sprintf("controllername=%s.%s controllermethod=%s", pkg, name, method) + return fmt.Sprintf("controller=%s.%s controllermethod=%s", pkg, name, method) } func getSortMapValue(data map[string]string) ([]string, []string) { @@ -227,20 +240,20 @@ func getControllerRoutes(controller Controller) map[string]string { return routes } -func getContrllerAllowMethos(iValue reflect.Value) map[string]string { +func getContrllerAllowMethos(v reflect.Value) map[string]string { names := make(map[string]string) - for _, name := range getContrllerAllMethos(iValue) { + for _, name := range getContrllerAllMethos(v) { names[name] = "" } - iValue = reflect.Indirect(iValue) - iType := iValue.Type() - if iValue.Kind() == reflect.Struct { + v = reflect.Indirect(v) + iType := v.Type() + if v.Kind() == reflect.Struct { // 删除嵌入非控制器方法 - for i := 0; i < iValue.NumField(); i++ { + for i := 0; i < v.NumField(); i++ { if iType.Field(i).Anonymous { - if !strings.HasSuffix(getControllerName(iValue.Field(i)), "Controller") { - for _, name := range getContrllerAllMethos(iValue.Field(i)) { + if !strings.HasSuffix(getControllerName(v.Field(i)), "Controller") { + for _, name := range getContrllerAllMethos(v.Field(i)) { delete(names, name) } } @@ -249,8 +262,8 @@ func getContrllerAllowMethos(iValue reflect.Value) map[string]string { // 追加嵌入控制器方法 for i := 0; i < iType.NumField(); i++ { if iType.Field(i).Anonymous { - if strings.HasSuffix(getControllerName(iValue.Field(i)), "Controller") { - for _, name := range getContrllerAllMethos(iValue.Field(i)) { + if strings.HasSuffix(getControllerName(v.Field(i)), "Controller") { + for _, name := range getContrllerAllMethos(v.Field(i)) { names[name] = "" } } @@ -261,8 +274,8 @@ func getContrllerAllowMethos(iValue reflect.Value) map[string]string { } // getContrllerAllMethos 函数获得一共类型包含指针类型的全部方法名称。 -func getContrllerAllMethos(iValue reflect.Value) []string { - iType := iValue.Type() +func getContrllerAllMethos(v reflect.Value) []string { + iType := v.Type() if iType.Kind() != reflect.Ptr { iType = reflect.New(iType).Type() } @@ -273,16 +286,13 @@ func getContrllerAllMethos(iValue reflect.Value) []string { return names } -func getControllerName(iValue reflect.Value) string { - if iValue.Kind() == reflect.Ptr && iValue.IsNil() { - iValue = reflect.New(iValue.Type().Elem()) - } - var name string - if iValue.Type().Implements(typeControllerName) && iValue.CanSet() { - name = iValue.MethodByName("ControllerName").Call(nil)[0].String() - } else { - name = reflect.Indirect(iValue).Type().Name() +func getControllerName(v reflect.Value) string { + if v.Kind() == reflect.Ptr && v.IsNil() { + v = reflect.New(v.Type().Elem()) } + + name := reflect.Indirect(v).Type().Name() + // 泛型名称 pos := strings.IndexByte(name, '[') if pos != -1 { name = name[:pos] @@ -301,7 +311,7 @@ func getRouteByName(name string) string { if names[i] == "By" { i++ if i == len(names) { - name = name + "/*" + name += "/*" } else { name = name + "/:" + names[i] } @@ -334,19 +344,21 @@ func getFirstUp(name string) string { return name } -// splitTitleName 方法基于路径首字符大写切割 +// splitTitleName 方法基于路径首字符大写切割。 func splitTitleName(str string) []string { var body []byte for i := range str { - if i != 0 && byteIn(str[i], 0x40) && byteIn(str[i-1], 0x60) { + switch { + case i != 0 && byteIn(str[i], 0x40) && byteIn(str[i-1], 0x60): body = append(body, ' ') body = append(body, str[i]) - } else if i != 0 && i != len(str)-1 && byteIn(str[i], 0x40) && byteIn(str[i-1], 0x40) && byteIn(str[i+1], 0x60) { + case i != 0 && i != len(str)-1 && byteIn(str[i], 0x40) && + byteIn(str[i-1], 0x40) && byteIn(str[i+1], 0x60): body = append(body, ' ') body = append(body, str[i]) - } else if byteIn(str[i], 0x40) && i != 0 { + case byteIn(str[i], 0x40) && i != 0: body = append(body, str[i]+0x20) - } else { + default: body = append(body, str[i]) } } diff --git a/converter.go b/converter.go deleted file mode 100644 index 528b7d1..0000000 --- a/converter.go +++ /dev/null @@ -1,872 +0,0 @@ -package eudore - -/* -功能1:获取和设置一个对象的属性 -func Get(i interface{}, key string) interface{} -func GetWithTags(i interface{}, key string, tags []string) (interface{}, error) -func Set(i interface{}, key string, val interface{}) error -func SetWithTags(i interface{}, key string, val interface{}, tags []string) error - -功能2:map和结构体相互转换 -func ConvertMap(i interface{}) map[interface{}]interface{} -func ConvertMapString(i interface{}) map[string]interface{} -func ConvertMapStringWithTags(i interface{}, tags []string) map[string]interface{} -func ConvertMapWithTags(i interface{}, tags []string) map[interface{}]interface{} -func ConvertTo(source interface{}, target interface{}) error -func ConvertToWithTags(source interface{}, target interface{}, tags []string) error - -功能3:sql结果Rows绑定 -func ConvertRows(rows *sql.Rows, i interface{}) error -func ConvertRowsWithTags(rows *sql.Rows, i interface{}, tags []string) error -*/ - -import ( - "encoding/json" - "fmt" - "math" - "reflect" - "strconv" - "strings" - "time" -) - -type convertValue struct { - Tags []string - All bool - Keys []string - Index int - Value interface{} -} - -type converter struct { - tags []string - results map[reflect.Value]interface{} -} - -// Set the properties of an object. The object must be a pointer type. If the target implements the Seter interface, the Set method is called. -// -// The path will be split using '.' and then look for the path in turn. -// -// When the object type selected in the path is ptr, it will check if it is empty. If the object is empty, it will be initialized by default. -// -// When the object type selected in the path is interface{}, if the object is empty, it will be initialized to map[string]interface{}, otherwise the value will be judged according to the value type. -// -// When the object type selected in the path is array, the path is converted to an object index to set the array element. If it cannot be converted, the element is appended. -// -// When the object type in the path is selected as a struct, the attribute name and the attribute tag 'alias' are used to match when selecting the attribute. -// -// If the type of the value is a string, it will be converted according to the target type set. -// -// If the target type is a string, the value is output as a string and then assigned. -// -// If the target type is an array, map, or struct, the json deserializes the set object. -// -// If the target type passed in is an array, map, or struct, the json deserializes the set object. -// -// 设置一个对象的属性,改对象必须是指针类型,如果目标实现Seter接口,调用Set方法。 -// -// 路径将使用'.'分割,然后依次寻找路径。 -// -// 当路径中选择对象类型为ptr时,会检查是否为空,对象为空会默认进行初始化。 -// -// 当路径中选择对象类型为interface{}时,如果对象为空会初始化为map[string]interface{},否则按值类型来判断下一步操作。 -// -// 当路径中选择对象类型为array时,路径会转换成对象索引来设置数组元素,无法转换则追加元素。 -// -// 当路径中选择对象类型为struct时,选择属性时会使用属性名称和属性标签'alias'来匹配。 -// -// 如果值的类型是字符串,会根据设置的目标类型来转换。 -// -// 如果目标类型是字符串,将会值输出成字符串然后赋值。 -// -// 如果目标类型是数组、map、结构体,会使用json反序列化设置对象。 -// -// 如果传入的目标类型是数组、map、结构体,会使用json反序列化设置对象。 -func Set(i interface{}, key string, val interface{}) error { - return SetWithTags(i, key, val, DefaultGetSetTags, false) -} - -// SetWithTags 函数和Set函数相同,可以额外设置tags。 -func SetWithTags(i interface{}, key string, val interface{}, tags []string, all bool) error { - if i == nil || key == "" { - return ErrConverterInputDataNil - } - // 检测目标是指针类型。 - if reflect.TypeOf(i).Kind() != reflect.Ptr { - return ErrConverterInputDataNotPtr - } - return (&convertValue{ - Tags: tags, - All: all, - Keys: strings.Split(key, "."), - Value: val, - }).setValue(reflect.ValueOf(i)) -} - -func (v *convertValue) setValue(iValue reflect.Value) error { - if len(v.Keys) == 0 { - return setWithValue(reflect.ValueOf(v.Value), iValue) - } - switch iValue.Kind() { - case reflect.Ptr: - if iValue.IsNil() { - // 将空指针赋值 - iValue.Set(reflect.New(iValue.Type().Elem())) - } - return v.setValue(iValue.Elem()) - case reflect.Interface: - return v.setInterface(iValue) - case reflect.Struct: - return v.setStruct(iValue) - case reflect.Map: - return v.setMap(iValue) - case reflect.Slice: - return v.setSlice(iValue) - case reflect.Array: - return v.setArray(iValue) - } - return fmt.Errorf(ErrFormatConverterSetTypeError, iValue.Kind(), v.Keys, v.Value) -} - -// 设置接口类型 -func (v *convertValue) setInterface(iValue reflect.Value) (err error) { - // 如果是空接口,初始化为map[string]interface{}类型 - if iValue.IsNil() { - if iValue.Type() != typeInterface { - return nil - } - iValue.Set(reflect.ValueOf(make(map[string]interface{}))) - } - // 创建一个可取地址的临时变量,并设置值用于下一步设置。 - newValue := reflect.New(iValue.Elem().Type()).Elem() - newValue.Set(iValue.Elem()) - err = v.setValue(newValue) - // 将修改后的值重新赋值给对象 - if err == nil { - iValue.Set(newValue) - } - return err -} - -// 处理结构体设置属性 -func (v *convertValue) setStruct(iValue reflect.Value) error { - // 查找属性是结构体的第几个属性。 - var index = getStructIndexOfTags(iValue.Type(), v.Keys[0], v.Tags) - // 未找到直接返回错误。 - if index == -1 { - return fmt.Errorf(ErrFormatConverterSetStructNotField, v.Keys[0]) - } - - // 获取结构体的属性 - structField := iValue.Field(index) - if !structField.CanSet() { - return fmt.Errorf(ErrFormatConverterNotCanset, v.Keys[0], iValue.Type().String()) - } - v.Keys = v.Keys[1:] - return v.setValue(structField) -} - -// 处理map -func (v *convertValue) setMap(iValue reflect.Value) error { - iType := iValue.Type() - // 对空map初始化 - if iValue.IsNil() { - iValue.Set(reflect.MakeMap(iType)) - } - - // 创建map需要匹配的key - mapKey := reflect.New(iType.Key()).Elem() - setWithString(mapKey, v.Keys[0]) - - newValue := reflect.New(iType.Elem()).Elem() - mapvalue := iValue.MapIndex(mapKey) - if mapvalue.Kind() != reflect.Invalid { - newValue.Set(mapvalue) - } - - v.Keys = v.Keys[1:] - err := v.setValue(newValue) - // 将修改后的mapvalue重新赋值给map - if err == nil { - iValue.SetMapIndex(mapKey, newValue) - } - return err -} - -func (v *convertValue) setArray(iValue reflect.Value) error { - index, err := strconv.Atoi(v.Keys[0]) - if err != nil || index < 0 || index >= iValue.Len() { - return fmt.Errorf(ErrFormatConverterSetArrayIndexInvalid, v.Keys[0], iValue.Len()) - } - v.Keys = v.Keys[1:] - return v.setValue(iValue.Index(index)) -} - -// 处理数组和切片 -func (v *convertValue) setSlice(iValue reflect.Value) error { - iType := iValue.Type() - // 空切片初始化,默认长度2 - if iValue.IsNil() { - iValue.Set(reflect.MakeSlice(iType, 0, 2)) - } - // 创建新元素的类型和值 - newValue := reflect.New(iType.Elem()).Elem() - index, err := strconv.Atoi(v.Keys[0]) - if err != nil { - index = -1 - } - if index > -1 { - // 新建数组替换原数组扩容 - if iValue.Cap() <= index { - iValue.Set(reflect.AppendSlice(reflect.MakeSlice(iType, 0, index+1), iValue)) - } - // 对数组长度扩充,新元素添加空值 - if iValue.Len() <= index { - iValue.SetLen(index + 1) - } - // 将原数组值设置给newValue - newValue.Set(iValue.Index(index)) - } - - v.Keys = v.Keys[1:] - err = v.setValue(newValue) - if err == nil { - if index > -1 { - iValue.Index(index).Set(newValue) - } else { - iValue.Set(reflect.Append(iValue, newValue)) - } - } - return err -} - -// Get method A more path to get an attribute from an object. -// -// The path will be split using '.' and then look for the path in turn. -// -// Structure attributes can use the structure tag 'alias' to match attributes. -// -// Returns a null value if the match fails. -// -// 根据路径来从一个对象获得一个属性。 -// -// 路径将使用'.'分割,然后依次寻找路径。 -// -// 结构体属性可以使用结构体标签'alias'来匹配属性。 -// -// 如果匹配失败直接返回空值。 -func Get(i interface{}, key string) interface{} { - val, err := getValue(i, key, DefaultGetSetTags, false) - if err != nil { - return nil - } - return val.Interface() -} - -// GetWithTags 函数和Get函数相同,可以额外设置tags,同时会返回error。 -func GetWithTags(i interface{}, key string, tags []string, all bool) (interface{}, error) { - val, err := getValue(i, key, tags, false) - if err != nil { - return nil, err - } - return val.Interface(), nil -} - -// GetWithValue 函数和Get函数相同,可以允许查找私有属性并返回reflect.Value。 -func GetWithValue(i interface{}, key string, tags []string, all bool) (reflect.Value, error) { - return getValue(i, key, tags, all) -} - -func getValue(i interface{}, key string, tags []string, all bool) (reflect.Value, error) { - val := reflect.ValueOf(i) - if i == nil { - return val, ErrConverterInputDataNil - } - if key == "" { - return val, nil - } - s := &convertValue{ - All: all, - Keys: strings.Split(key, "."), - Tags: tags, - } - val, err := s.getValue(val) - if err != nil { - return val, err - } - return val, nil -} - -// 从目标类型获取字符串路径的属性 -func (v *convertValue) getValue(iValue reflect.Value) (reflect.Value, error) { - if len(v.Keys) == v.Index { - return iValue, nil - } - switch iValue.Kind() { - case reflect.Ptr, reflect.Interface: - if iValue.IsNil() { - return iValue, v.newGetError("is nil ptr or interface") - } - return v.getValue(iValue.Elem()) - case reflect.Struct: - return v.getStruct(iValue) - case reflect.Map: - return v.getMap(iValue) - case reflect.Array, reflect.Slice: - return v.getSlice(iValue) - } - return iValue, v.newGetError("not find sub path") -} - -// 处理结构体对象的读取 -func (v *convertValue) getStruct(iValue reflect.Value) (reflect.Value, error) { - // 查找key对应的属性索引,不存在返回-1。 - var index = getStructIndexOfTags(iValue.Type(), v.Keys[v.Index], v.Tags) - if index == -1 { - return iValue, v.newGetError("not field") - } - // 获取key对应结构的属性。 - structField := iValue.Field(index) - if structField.CanSet() || v.All { - v.Index++ - return v.getValue(structField) - } - return iValue, v.newGetError("field is not CanSet") -} - -// 处理map读取属性 -func (v *convertValue) getMap(iValue reflect.Value) (reflect.Value, error) { - // 检测map是否为空 - if iValue.IsNil() { - return iValue, v.newGetError("is nil map") - } - // 创建map需要的key - mapKey := reflect.New(iValue.Type().Key()).Elem() - err := setWithString(mapKey, v.Keys[v.Index]) - if err != nil { - return iValue, v.newGetError("map key is invalid") - } - - // 获得map的value, 如果值无效则返回空。 - mapvalue := iValue.MapIndex(mapKey) - if mapvalue.Kind() == reflect.Invalid { - return iValue, v.newGetError("map value is invalid") - } - v.Index++ - return v.getValue(mapvalue) -} - -// 处理数组切片读取属性 -func (v *convertValue) getSlice(iValue reflect.Value) (reflect.Value, error) { - // 检测切片是否为空 - if iValue.Kind() == reflect.Slice && iValue.IsNil() { - return iValue, v.newGetError("is nil slice") - } - // 检测索引是否存在 - index, err := strconv.Atoi(v.Keys[v.Index]) - if err != nil || index < 0 || iValue.Len() <= index { - return iValue, v.newGetError("slice index is invalid") - } - v.Index++ - return v.getValue(iValue.Index(index)) -} - -func (v *convertValue) newGetError(str string) error { - return fmt.Errorf(ErrFormatConverterGet, strings.Join(v.Keys[:v.Index+1], "."), str) -} - -// ConvertMapString 函数将一个map或struct转换成map[string]interface{}。 -func ConvertMapString(i interface{}) map[string]interface{} { - return ConvertMapStringWithTags(i, DefaultConvertTags) -} - -// ConvertMapStringWithTags 函数与ConvertMapString相同,允许使用额外的tags。 -func ConvertMapStringWithTags(i interface{}, tags []string) map[string]interface{} { - c := &converter{ - tags: tags, - results: make(map[reflect.Value]interface{}), - } - // 其他类型直接返回 - val, ok := c.convertMapString(reflect.ValueOf(i)).(map[string]interface{}) - if ok { - return val - } - return nil -} - -// 将一个map或结构体对象转换成map[string]interface{}返回。 -func (c *converter) convertMapString(iValue reflect.Value) interface{} { - result, ok := c.results[iValue] - if ok { - return result - } - switch iValue.Kind() { - // 接口类型解除引用 - case reflect.Interface: - // 空接口直接返回 - if iValue.Elem().Kind() == reflect.Invalid { - return iValue.Interface() - } - return c.convertMapString(iValue.Elem()) - // 指针类型解除引用 - case reflect.Ptr: - // 空指针直接返回 - if iValue.IsNil() { - return iValue.Interface() - } - return c.convertMapString(iValue.Elem()) - // 将map转换成map[string]interface{} - case reflect.Map: - val := make(map[string]interface{}) - c.results[iValue] = val - c.convertMapstrngMapToMapString(iValue, val) - return val - // 将结构体转换成map[string]interface{} - case reflect.Struct: - val := make(map[string]interface{}) - c.results[iValue] = val - c.convertMapstringStructToMapString(iValue, val) - return val - } - // 其他类型直接返回 - return iValue.Interface() -} - -// 结构体转换成map[string]interface{} -func (c *converter) convertMapstringStructToMapString(iValue reflect.Value, val map[string]interface{}) { - iType := iValue.Type() - // 遍历结构体的属性 - for i := 0; i < iType.NumField(); i++ { - fieldKey := iType.Field(i) - fieldValue := iValue.Field(i) - if fieldValue.CanSet() { - // map设置键位结构体的名称,值为结构体值转换,基本类型会直接返回。 - val[getStructNameOfTags(fieldKey, c.tags)] = c.convertMapString(fieldValue) - } - } -} - -// 将map转换成map[string]interface{} -func (c *converter) convertMapstrngMapToMapString(iValue reflect.Value, val map[string]interface{}) { - // 遍历map的全部keys - for _, key := range iValue.MapKeys() { - v := iValue.MapIndex(key) - // 设置新map的键为原map的字符串输出,未支持接口转换 - // 设置新map的值为原map匹配的值的转换,值为基本类型会直接返回。 - val[fmt.Sprint(key.Interface())] = c.convertMapString(v) - } -} - -// ConvertMap 函数将一个map或struct转换成map[interface{}]interface{}。 -func ConvertMap(i interface{}) map[interface{}]interface{} { - return ConvertMapWithTags(i, DefaultConvertTags) -} - -// ConvertMapWithTags 函数与ConvertMap相同,允许使用额外的tags。 -func ConvertMapWithTags(i interface{}, tags []string) map[interface{}]interface{} { - c := &converter{ - tags: tags, - results: make(map[reflect.Value]interface{}), - } - // 其他类型直接返回 - val, ok := c.convertMap(reflect.ValueOf(i)).(map[interface{}]interface{}) - if ok { - return val - } - return nil -} - -// 将一个map或结构体对象转换成map[interface{}]interface{}返回。 -func (c *converter) convertMap(iValue reflect.Value) interface{} { - result, ok := c.results[iValue] - if ok { - return result - } - switch iValue.Kind() { - case reflect.Interface: - if iValue.Elem().Kind() == reflect.Invalid { - return iValue.Interface() - } - return c.convertMap(iValue.Elem()) - case reflect.Ptr: - if iValue.IsNil() { - return iValue.Interface() - } - return c.convertMap(iValue.Elem()) - case reflect.Map: - val := make(map[interface{}]interface{}) - c.results[iValue] = val - c.convertMapMapToMap(iValue, val) - return val - case reflect.Struct: - val := make(map[interface{}]interface{}) - c.results[iValue] = val - c.convertMapStructToMap(iValue, val) - return val - } - return iValue.Interface() -} - -// 结构体转换成map[interface{}]interface{} -func (c *converter) convertMapStructToMap(iValue reflect.Value, val map[interface{}]interface{}) { - iType := iValue.Type() - // 遍历结构体的属性 - for i := 0; i < iType.NumField(); i++ { - fieldKey := iType.Field(i) - fieldValue := iValue.Field(i) - if fieldValue.CanSet() { - // map设置键位结构体的名称,值为结构体值转换,基本类型会直接返回。 - val[getStructNameOfTags(fieldKey, c.tags)] = c.convertMap(fieldValue) - } - } -} - -// 将map转换成map[interface{}]interface{} -func (c *converter) convertMapMapToMap(iValue reflect.Value, val map[interface{}]interface{}) { - // 遍历map的全部keys - for _, key := range iValue.MapKeys() { - v := iValue.MapIndex(key) - // 设置新map的键为原map的字符串输出,未支持接口转换 - // 设置新map的值为原map匹配的值的转换,值为基本类型会直接返回。 - val[key.Interface()] = c.convertMap(v) - } -} - -// checkValueIsZero 函数检测一个值是否为空, 修改go.1.13 refletv.Value.IsZero方法。 -func checkValueIsZero(iValue reflect.Value) bool { - switch iValue.Kind() { - case reflect.Bool: - return !iValue.Bool() - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return iValue.Int() == 0 - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return iValue.Uint() == 0 - case reflect.Float32, reflect.Float64: - return math.Float64bits(iValue.Float()) == 0 - case reflect.Complex64, reflect.Complex128: - c := iValue.Complex() - return math.Float64bits(real(c)) == 0 && math.Float64bits(imag(c)) == 0 - case reflect.String: - return iValue.Len() == 0 - case reflect.UnsafePointer: - // 兼容go1.9 - // if iValue.CanSet(){ - return iValue.Interface() == nil - //} - case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: - return iValue.IsNil() - case reflect.Array: - for i := 0; i < iValue.Len(); i++ { - if !checkValueIsZero(iValue.Index(i)) { - return false - } - } - case reflect.Struct: - for i := 0; i < iValue.NumField(); i++ { - if !checkValueIsZero(iValue.Field(i)) { - return false - } - } - } - return true -} - -// 通过字符串获取结构体属性的索引 -func getStructIndexOfTags(iType reflect.Type, name string, tags []string) int { - // 遍历匹配 - for i := 0; i < iType.NumField(); i++ { - typeField := iType.Field(i) - // 字符串为结构体名称或结构体属性标签的值,则匹配返回索引。 - if typeField.Name == name { - return i - } - for _, tag := range tags { - if typeField.Tag.Get(tag) == name { - return i - } - } - } - return -1 -} - -func getStructNameOfTags(field reflect.StructField, tags []string) string { - for _, tag := range tags { - name := field.Tag.Get(tag) - if name != "" { - return name - } - } - return field.Name -} - -// getIndirectAllValue 函数获得解除引用的全部类型和值。 -func getIndirectAllValue(iValue reflect.Value) (types []reflect.Type, values []reflect.Value) { - for { - types = append(types, iValue.Type()) - values = append(values, iValue) - switch iValue.Kind() { - case reflect.Ptr, reflect.Interface: - if iValue.IsNil() { - return - } - iValue = iValue.Elem() - default: - return - } - } -} - -func setWithValue(sValue reflect.Value, tValue reflect.Value) error { - if sValue.Kind() == reflect.Ptr || sValue.Kind() == reflect.Interface || tValue.Kind() == reflect.Ptr || tValue.Kind() == reflect.Interface { - stypes, svalues := getIndirectAllValue(sValue) - ttypes, tvalues := getIndirectAllValue(tValue) - for i, ttype := range ttypes { - for j, stype := range stypes { - // 转换接口类型、相同类型、type别名类型 - if stype.ConvertibleTo(ttype) && tvalues[i].CanSet() { - return setWithValueData(svalues[j], tvalues[i]) - } - } - } - sValue = svalues[len(svalues)-1] - tValue = tvalues[len(tvalues)-1] - - // 目标类型如果是空指针,则尝试进行初始化并转换 - if tValue.Kind() == reflect.Ptr && tValue.IsNil() { - newValue := reflect.New(tValue.Type().Elem()) - err := setWithValue(sValue, newValue) - if err == nil { - tValue.Set(newValue) - } - return err - } - } - return setWithValueData(sValue, tValue) -} - -func setWithValueData(sValue reflect.Value, tValue reflect.Value) error { - sType := sValue.Type() - tType := tValue.Type() - switch { - case sType == tType: - tValue.Set(sValue) - return nil - case sType.Kind() == reflect.String: - return setWithString(tValue, sValue.String()) - case tType.Kind() == reflect.String: - tValue.SetString(getWithValueString(sType, sValue)) - return nil - case sType.ConvertibleTo(tType): - tValue.Set(sValue.Convert(tType)) - return nil - case sType.Kind() == reflect.Slice: - err := setWithValueData(reflect.Indirect(sValue.Index(0)), tValue) - if err == nil { - return nil - } - } - return fmt.Errorf(ErrFormatConverterSetWithValue, sValue.Type().String(), tValue.Type().String()) -} - -func getWithValueString(t reflect.Type, v reflect.Value) string { - if t.Kind() == reflect.Slice || t.Kind() == reflect.Array { - switch t.Elem().Kind() { - case reflect.String: - if v.Len() > 0 { - return v.Index(0).String() - } - case reflect.Uint8, reflect.Int32: - return v.Convert(typeString).String() - } - } - return fmt.Sprintf("%#v", v.Interface()) -} - -// 将一个字符串赋值给对象 -func setWithString(iValue reflect.Value, val string) error { - val = strings.TrimSpace(val) - switch iValue.Kind() { - case reflect.Int: - return setIntField(val, 0, iValue) - case reflect.Int8: - return setIntField(val, 8, iValue) - case reflect.Int16: - return setIntField(val, 16, iValue) - case reflect.Int32: - return setIntField(val, 32, iValue) - case reflect.Int64: - return setIntField(val, 64, iValue) - case reflect.Uint: - return setUintField(val, 0, iValue) - case reflect.Uint8: - return setUintField(val, 8, iValue) - case reflect.Uint16: - return setUintField(val, 16, iValue) - case reflect.Uint32: - return setUintField(val, 32, iValue) - case reflect.Uint64: - return setUintField(val, 64, iValue) - case reflect.Bool: - return setBoolField(val, iValue) - case reflect.Float32: - return setFloatField(val, 32, iValue) - case reflect.Float64: - return setFloatField(val, 64, iValue) - case reflect.Complex64: - return setComplexField(val, 32, iValue) - case reflect.Complex128: - return setComplexField(val, 64, iValue) - // 目标类型是字符串直接设置 - case reflect.String: - iValue.SetString(val) - case reflect.Struct: - if iValue.Type().ConvertibleTo(typeTimeTime) { - return setTimeField(val, iValue) - } - return json.Unmarshal([]byte(val), iValue.Addr().Interface()) - case reflect.Slice: - switch iValue.Type().Elem().Kind() { - case reflect.Uint8, reflect.Int32: - iValue.Set(reflect.ValueOf(val).Convert(iValue.Type())) - default: - return json.Unmarshal([]byte(val), iValue.Addr().Interface()) - } - case reflect.Array, reflect.Map: - return json.Unmarshal([]byte(val), iValue.Addr().Interface()) - case reflect.Interface: - if iValue.Type() == typeInterface { - iValue.Set(reflect.ValueOf(val)) - } - // 目标类型是指针进行解引用然后赋值。 - case reflect.Ptr: - if !iValue.Elem().IsValid() { - iValue.Set(reflect.New(iValue.Type().Elem())) - } - return setWithString(iValue.Elem(), val) - default: - return fmt.Errorf(ErrFormatConverterSetStringUnknownType, iValue.Kind().String()) - } - return nil -} - -// 设置int类型属性 -func setIntField(val string, bitSize int, field reflect.Value) error { - if val == "" { - val = "0" - } - intVal, err := strconv.ParseInt(val, 10, bitSize) - // 兼容 time.Duration及衍生类型 - if err != nil && field.Kind() == reflect.Int64 { - var t time.Duration - t, err = time.ParseDuration(val) - if err != nil { - return err - } - intVal = int64(t) - } - if err == nil { - field.SetInt(intVal) - } - return err -} - -// 设置无符号整形属性 -func setUintField(val string, bitSize int, field reflect.Value) error { - if val == "" { - val = "0" - } - uintVal, err := strconv.ParseUint(val, 10, bitSize) - if err == nil { - field.SetUint(uintVal) - } - return err -} - -// 设置布尔类型属性 -func setBoolField(val string, field reflect.Value) error { - if val == "" { - val = "false" - } - boolVal, err := strconv.ParseBool(val) - if err == nil { - field.SetBool(boolVal) - } - return err -} - -// 设置复数 -func setComplexField(val string, bitSize int, field reflect.Value) error { - val = strings.TrimSuffix(strings.TrimSuffix(strings.TrimPrefix(val, "("), "i"), ")") - pos := strings.Index(val, "+") - if pos == -1 { - pos = len(val) - val += "+0" - } - - read, err := strconv.ParseFloat(val[:pos], bitSize) - if err != nil { - return err - - } - image, err := strconv.ParseFloat(val[pos+1:], bitSize) - if err != nil { - return err - } - - field.SetComplex(complex(read, image)) - return nil -} - -// 设置浮点类型属性 -func setFloatField(val string, bitSize int, field reflect.Value) error { - if val == "" { - val = "0.0" - } - floatVal, err := strconv.ParseFloat(val, bitSize) - if err == nil { - field.SetFloat(floatVal) - } - return err -} - -// timeformats 定义允许使用的时间格式。 -var timeformats = []string{ - time.ANSIC, - time.UnixDate, - time.RubyDate, - time.RFC822, - time.RFC822Z, - time.RFC850, - time.RFC1123, - time.RFC1123Z, - time.RFC3339, - time.RFC3339Nano, - time.Kitchen, - time.Stamp, - time.StampMilli, - time.StampMicro, - time.StampNano, - "2006-1-02", - "2006-01-02", - "15:04:05", - "2006-01-02 15:04:05", - "2006-01-02T15:04:05Z07:00", - "2006-01-02T15:04:05.999999999Z07:00", -} - -// TimeParse 方法通过解析内置支持的时间格式。 -func setTimeField(str string, iValue reflect.Value) (err error) { - var t time.Time - for _, f := range timeformats { - t, err = time.Parse(f, str) - if err == nil { - if iValue.Type() != typeTimeTime { - iValue.Set(reflect.ValueOf(t).Convert(iValue.Type())) - } else { - iValue.Set(reflect.ValueOf(t)) - } - return - } - } - return -} diff --git a/converter2.go b/converter2.go deleted file mode 100644 index e18fc6a..0000000 --- a/converter2.go +++ /dev/null @@ -1,190 +0,0 @@ -package eudore - -import ( - "fmt" - "reflect" - "unsafe" -) - -// ConvertTo 将一个对象属性复制给另外一个对象,可转换对象属性会覆盖原值。 -func ConvertTo(source interface{}, target interface{}) error { - return ConvertToWithTags(source, target, DefaultConvertTags) -} - -// ConvertToWithTags 函数与ConvertTo相同,允许使用额外的tags。 -func ConvertToWithTags(source interface{}, target interface{}, tags []string) error { - if source == nil { - return ErrConverterInputDataNil - } - if target == nil { - return ErrConverterTargetDataNil - } - - // 检测目标是指针类型。 - if reflect.TypeOf(target).Kind() != reflect.Ptr { - return ErrConverterInputDataNotPtr - } - - c := &convertMapping{ - Tags: tags, - Refs: make(map[unsafe.Pointer]reflect.Value), - } - return c.convertToData(reflect.ValueOf(source), reflect.ValueOf(target)) -} - -type convertMapping struct { - Tags []string - Refs map[unsafe.Pointer]reflect.Value -} - -func getValuePointer(iValue reflect.Value) unsafe.Pointer { - val := *(*innerValue)(unsafe.Pointer(&iValue)) - return val.ptr -} - -type innerValue struct { - _ *int - ptr unsafe.Pointer - flag uintptr -} - -func (c *convertMapping) convertToData(sValue reflect.Value, tValue reflect.Value) error { - switch sValue.Kind() { - case reflect.Ptr, reflect.Map, reflect.Interface: - if !sValue.IsNil() && tValue.CanSet() { - ref, ok := c.Refs[getValuePointer(sValue)] - if ok { - if ref.Type().ConvertibleTo(tValue.Type()) { - tValue.Set(ref.Convert(tValue.Type())) - return nil - } - } - } - } - - skind := sValue.Kind() - tkind := tValue.Kind() - switch { - case checkValueIsZero(sValue): - return nil - case sValue.Kind() == reflect.Interface: - c.Refs[getValuePointer(sValue)] = tValue - return c.convertToData(sValue.Elem(), tValue) - case tValue.Kind() == reflect.Interface: - if tValue.IsNil() { - newValue := reflect.New(sValue.Type()).Elem() - if newValue.Type().ConvertibleTo(tValue.Type()) { - err := c.convertToData(sValue, newValue) - if err == nil { - tValue.Set(newValue.Convert(tValue.Type())) - } - return err - } - } else { - return c.convertToData(sValue, tValue.Elem()) - } - case sValue.Kind() == reflect.Ptr: - c.Refs[getValuePointer(sValue)] = tValue - return c.convertToData(sValue.Elem(), tValue) - case tValue.Kind() == reflect.Ptr: - if tValue.IsNil() { - newValue := reflect.New(tValue.Type().Elem()) - err := c.convertToData(sValue, newValue.Elem()) - if err == nil { - tValue.Set(newValue) - } - return err - } - return c.convertToData(sValue, tValue.Elem()) - case skind == reflect.Map && tkind == reflect.Map: - c.convertToMapToMap(sValue, tValue) - case skind == reflect.Map && tkind == reflect.Struct: - c.convertToMapToStruct(sValue, tValue) - case skind == reflect.Struct && tkind == reflect.Map: - c.convertToStructToMap(sValue, tValue) - case skind == reflect.Struct && tkind == reflect.Struct: - c.convertToStructToStruct(sValue, tValue) - case (skind == reflect.Slice || skind == reflect.Array) && (tkind == reflect.Slice || tkind == reflect.Array): - c.convertToSlice(sValue, tValue) - default: - return setWithValueData(sValue, tValue) - } - return nil -} - -func (c *convertMapping) convertToMapToMap(sValue reflect.Value, tValue reflect.Value) { - tType := tValue.Type() - if tValue.IsNil() { - tValue.Set(reflect.MakeMap(tType)) - } - - // TODO: map to map - // c.Refs[getValuePointer(sValue)] = tValue - for _, key := range sValue.MapKeys() { - mapvalue := reflect.New(tType.Elem()).Elem() - if err := c.convertToData(sValue.MapIndex(key), mapvalue); err == nil { - tValue.SetMapIndex(key, mapvalue) - } - } - -} - -func (c *convertMapping) convertToMapToStruct(sValue reflect.Value, tValue reflect.Value) { - tType := tValue.Type() - for _, key := range sValue.MapKeys() { - index := getStructIndexOfTags(tType, fmt.Sprint(key.Interface()), c.Tags) - if index == -1 || !tValue.Field(index).CanSet() { - continue - } - c.convertToData(sValue.MapIndex(key), tValue.Field(index)) - } -} - -func (c *convertMapping) convertToStructToMap(sValue reflect.Value, tValue reflect.Value) { - sType := sValue.Type() - tType := tValue.Type() - if tValue.IsNil() { - tValue.Set(reflect.MakeMap(tType)) - } - for i := 0; i < sType.NumField(); i++ { - if checkValueIsZero(sValue.Field(i)) || !sValue.Field(i).CanSet() { - continue - } - - mapvalue := reflect.New(tType.Elem()).Elem() - if err := c.convertToData(sValue.Field(i), mapvalue); err == nil { - tValue.SetMapIndex(reflect.ValueOf(sType.Field(i).Name), mapvalue) - } - } -} - -func (c *convertMapping) convertToStructToStruct(sValue reflect.Value, tValue reflect.Value) { - sType := sValue.Type() - tType := tValue.Type() - for i := 0; i < sType.NumField(); i++ { - if checkValueIsZero(sValue.Field(i)) { - continue - } - - index := getStructIndexOfTags(tType, sType.Field(i).Name, c.Tags) - if index == -1 || !tValue.Field(index).CanSet() { - continue - } - c.convertToData(sValue.Field(i), tValue.Field(index)) - } -} - -func (c *convertMapping) convertToSlice(sValue reflect.Value, tValue reflect.Value) { - num := sValue.Len() - tValue.Len() - if num > 0 && tValue.CanSet() { - tValue.Set(reflect.AppendSlice(tValue, reflect.MakeSlice(tValue.Type(), num, num))) - } - if num > 0 { - num = tValue.Len() - } else { - num = sValue.Len() - } - for i := 0; i < num; i++ { - c.convertToData(sValue.Index(i), tValue.Index(i)) - } -} diff --git a/daemon/command.go b/daemon/command.go new file mode 100644 index 0000000..6ec01fb --- /dev/null +++ b/daemon/command.go @@ -0,0 +1,232 @@ +package daemon + +/* +利用系统信号进制,执行start、daemon、stop、status、restart命令来操作进程。 +进程pid存储在pid文件中。 +*/ + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "os/exec" + "strconv" + "strings" + "syscall" + "time" + + "github.com/eudore/eudore" +) + +const ( + CommandStart = "start" + CommandDaemon = "daemon" + CommandStatus = "status" + CommandStop = "stop" + CommandRestart = "restart" + CommandDisable = "disable" +) + +// Command is a command parser that performs the corresponding behavior based on the current command. +// +// Command 对象是一个命令解析器,根据当前命令执行对应行为。 +type Command struct { + Command string + Pidfile string + Args []string + Envs []string + Print func(string, ...any) +} + +// Run 方法启动Command解析。 +func (cmd *Command) Run(ctx context.Context) (err error) { + switch cmd.Command { + case CommandStart: + return cmd.Start(ctx) + case CommandDaemon: + return cmd.Daemon(ctx) + case CommandStatus: + err = cmd.Status() + case CommandStop: + err = cmd.Stop() + case CommandRestart: + err = cmd.Restart() + default: + err = errors.New("undefined command " + cmd.Command) + cmd.Print("undefined command %s, support command: start/status/stop/restart/daemon.", cmd.Command) + } + + if err != nil { + cmd.Print("%s is false, error: %w.", cmd.Command, err) + return err + } + + pid, _ := cmd.readpid() + cmd.Print("%s is true, pid is %d, pidfile in %s.", cmd.Command, pid, cmd.Pidfile) + cmd.Wait(pid) + + return fmt.Errorf("daemon is %s %w", cmd.Command, context.Canceled) +} + +func (cmd *Command) Wait(p int) { + t := eudore.GetAnyByString(os.Getenv(eudore.EnvEudoreDaemonTimeout), 60) + if t < 0 || (cmd.Command != CommandStop && cmd.Command != CommandRestart) { + return + } + + for i := 0; i <= t; i++ { + pid, err := cmd.readpid() + switch { + case cmd.Command == CommandStop && err != nil: + cmd.Print("stop successfully, wait time %ds.", i) + return + case err != nil: + cmd.Print("%s read pid error: %w.", cmd.Command, err) + case cmd.Command == CommandRestart && pid != p: + cmd.Print("restart successfully, new pid is %d, wait time %ds.", pid, i) + return + } + time.Sleep(time.Second) + } + if t > 0 { + cmd.Print("%s failed, wait time %ds timeout.", cmd.Command, t) + } +} + +// Start execute the startup function and write the pid to the file. +// +// Start 函数执行启动函数,并将pid写入文件。 +func (cmd *Command) Start(ctx context.Context) error { + // 测试文件是否被锁定 + pid, err := cmd.readpid() + if err != nil && !os.IsNotExist(err) { + return err + } + + restartid := eudore.GetAny[int](os.Getenv(eudore.EnvEudoreDaemonRestartID)) + if pid != 0 && pid == restartid { + return nil + } + + err = cmd.Status() + if err == nil { + return fmt.Errorf("process exites pid %d", pid) + } + + // 写入pid + return cmd.writepid(ctx) +} + +// Daemon Start the process in the background. If it is not started in the background, create a background process. +// +// Daemon 函数后台启动进程。若不是后台启动,则创建一个后台进程。 +func (cmd *Command) Daemon(ctx context.Context) error { + if eudore.GetAny[bool](os.Getenv(eudore.EnvEudoreDaemonEnable)) { + return cmd.Start(ctx) + } + + // 测试文件是否被锁定 + pid, err := cmd.readpid() + if err != nil && !os.IsNotExist(err) { + return err + } + err = cmd.Status() + if err == nil { + return fmt.Errorf("process exites pid %d", pid) + } + + fork := exec.Command(os.Args[0], os.Args[1:]...) + fork.Args = append(fork.Args, cmd.Args...) + fork.Env = append(os.Environ(), fmt.Sprintf("%s=%d", eudore.EnvEudoreDaemonEnable, 1)) + fork.Env = append(fork.Env, cmd.Envs...) + fork.Stdout = os.Stdout + err = fork.Start() + if err != nil { + return err + } + return fmt.Errorf("daemon start %w", context.Canceled) +} + +// Status 函数调用系统命令,想进程发送信号 0。 +func (cmd *Command) Status() error { + return cmd.ExecSignal(syscall.Signal(0x00)) +} + +// Stop 函数调用系统命令,想进程发送信号syscall.SIGTERM。 +func (cmd *Command) Stop() error { + return cmd.ExecSignal(syscall.Signal(0x0f)) +} + +// Reload 函数调用系统命令,想进程发送信号syscall.SIGUSR1。 +func (cmd *Command) Reload() error { + return cmd.ExecSignal(syscall.Signal(0x0a)) +} + +// Restart 函数调用系统命令,想进程发送信号syscall.SIGUSR2。 +func (cmd *Command) Restart() error { + return cmd.ExecSignal(syscall.Signal(0x0c)) +} + +// ExecSignal 函数向pidfile内的进程发送指定命令。 +func (cmd *Command) ExecSignal(sig os.Signal) error { + pid, err := cmd.readpid() + if err != nil { + return err + } + + process, err := os.FindProcess(pid) + if err != nil { + return fmt.Errorf("find process %d error: %w", pid, err) + } + + err = process.Signal(sig) + if err != nil { + os.Remove(cmd.Pidfile) + return err + } + return nil +} + +// Read the value in the pid file. +// +// 读取pid文件内的值。 +func (cmd *Command) readpid() (int, error) { + file, err := os.OpenFile(cmd.Pidfile, os.O_RDONLY, 0o644) + if err != nil { + return 0, err + } + defer file.Close() + + id, err := io.ReadAll(file) + if err != nil { + return 0, err + } + return strconv.Atoi(strings.TrimSpace(string(id))) +} + +// Open and lock the pid file and write the value of pid. +// +// 打开并锁定pid文件,写入pid的值。 +func (cmd *Command) writepid(ctx context.Context) (err error) { + file, err := os.OpenFile(cmd.Pidfile, os.O_WRONLY|os.O_CREATE, 0o644) + if err != nil { + return + } + defer file.Close() + + _, err = fmt.Fprintf(file, "%d", os.Getpid()) + if err != nil { + return + } + go func() { + // 关闭删除pid文件 + <-ctx.Done() + pid, err := cmd.readpid() + if err == nil && pid == os.Getpid() { + os.Remove(cmd.Pidfile) + } + }() + return nil +} diff --git a/daemon/daemon.go b/daemon/daemon.go new file mode 100644 index 0000000..a95bbe3 --- /dev/null +++ b/daemon/daemon.go @@ -0,0 +1,111 @@ +/* +Package daemon 实现应用进程启动命令、后台启动、信号处理、热重启的代码支持。 + +# 启动命令 + + app --command=start/status/stop/restart/deamon/disable --pidfile=/run/run/pidfile + +command: + + start 写入pid前台启动 + daemon 写入pid后台启动 + status 读取pid判断进程存在 + stop 读取pid发送syscall.SIGTERM信号(15) + restart 读取pid发送syscall.SIGUSR2信号(12) + disable 跳过启动命令处理 + +# 后台启动 + + func main() { + daemon.StartDaemon() + } + +# 信号处理 + +# 热重启 + +使用command组件或kill命令发送SIGUSR2信号。 + +父进程接受SIGUSR2信号后,传递当前Listen FD和ppid后台启动子进程; +子进程启动初始化后完成向父进程发送SIGTERM信号关闭父进程。 +*/ +package daemon + +import ( + "context" + "fmt" + "os" + "os/exec" + "syscall" + + "github.com/eudore/eudore" +) + +// NewParseCommand 函数创建Command配置解析函数。 +func NewParseDaemon(app *eudore.App) eudore.ConfigParseFunc { + return func(ctx context.Context, conf eudore.Config) error { + sig := &Signal{ + Chan: make(chan os.Signal), + Funcs: make(map[os.Signal][]SignalFunc), + } + sig.Register(syscall.Signal(0x02), AppStop) + sig.Register(syscall.Signal(0x0f), AppStop) + app.SetValue(eudore.ContextKeyDaemonSignal, sig) + go sig.Run(ctx) + + cmd := &Command{ + Command: eudore.GetAny(conf.Get("command"), CommandStart), + Pidfile: eudore.GetAny(conf.Get("pidfile"), eudore.DefaultDaemonPidfile), + Print: func(format string, args ...any) { + fmt.Printf(format+"\r\n", args...) //nolint:forbidigo + }, + } + if cmd.Command == CommandDisable { + return nil + } + app.Infof("command is %s, pidfile in %s, process pid is %d.", cmd.Command, cmd.Pidfile, os.Getpid()) + if cmd.Command != CommandStart && cmd.Command != CommandDaemon { + app.SetValue(eudore.ContextKeyLogger, eudore.DefaultLoggerNull) + } + + app.SetValue(eudore.ContextKeyDaemonCommand, cmd) + return cmd.Run(ctx) + } +} + +func NewParseRestart() eudore.ConfigParseFunc { + return func(ctx context.Context, conf eudore.Config) error { + sig, ok := ctx.Value(eudore.ContextKeyDaemonSignal).(*Signal) + if ok { + sig.Register(syscall.Signal(0x0c), AppRestart) + } + + cmd, ok := ctx.Value(eudore.ContextKeyDaemonCommand).(*Command) + if ok { + restartid := eudore.GetAny[int](os.Getenv(eudore.EnvEudoreDaemonRestartID)) + if restartid == 0 { + return nil + } + err := cmd.ExecSignal(syscall.Signal(0x0f)) + if err != nil { + return err + } + return cmd.writepid(ctx) + } + return nil + } +} + +// StartDaemon 函数直接后台启动程序。 +func StartDaemon(envs ...string) { + if eudore.GetAny[bool](os.Getenv(eudore.EnvEudoreDaemonEnable)) { + return + } + + cmd := exec.Command(os.Args[0], os.Args[1:]...) + cmd.Env = append(os.Environ(), fmt.Sprintf("%s=%d", eudore.EnvEudoreDaemonEnable, 1)) + cmd.Env = append(cmd.Env, envs...) + cmd.Stdout = os.Stdout + _ = cmd.Start() + os.Exit(0) +} diff --git a/daemon/restart.go b/daemon/restart.go new file mode 100644 index 0000000..faa797f --- /dev/null +++ b/daemon/restart.go @@ -0,0 +1,114 @@ +package daemon + +import ( + "context" + "fmt" + "net" + "os" + "os/exec" + "path/filepath" + "strings" + "syscall" + + "github.com/eudore/eudore" +) + +var ( + listeners = map[string]net.Listener{} + listenersfd = map[string]uintptr{} +) + +//nolint:gochecknoinits +func init() { + for i, addr := range strings.Split(os.Getenv(eudore.EnvEudoreDaemonListeners), " ,") { + if addr == "" { + continue + } + listenersfd[addr] = uintptr(i + 3) + } + + listen := eudore.DefaultServerListen + eudore.DefaultServerListen = func(network, address string) (net.Listener, error) { + addr := fmt.Sprintf("%s://%s", network, address) + var ln net.Listener + var err error + + fd, ok := listenersfd[addr] + if ok { + ln, err = net.FileListener(os.NewFile(fd, "")) + } else { + ln, err = listen(network, address) + } + + if err == nil { + listeners[addr] = ln + } + return ln, err + } +} + +func AppStop(ctx context.Context) error { + app, ok := ctx.Value(eudore.ContextKeyApp).(*eudore.App) + if ok { + app.CancelFunc() + } + return nil +} + +type filer interface { + File() (*os.File, error) +} + +func AppRestart(ctx context.Context) error { + path := os.Args[0] + dir, err := os.Getwd() + if err != nil { + return err + } + if filepath.Base(path) == path { + path, err = exec.LookPath(path) + if err != nil { + return err + } + } + + // get addrs and socket listen fds + addrs := make([]string, 0, len(listeners)) + files := make([]*os.File, 0, len(listeners)) + for addr, ln := range listeners { + filer, ok := ln.(filer) + if ok { + fd, err := filer.File() + if err != nil { + return err + } + addrs = append(addrs, addr) + files = append(files, fd) + syscall.CloseOnExec(int(fd.Fd())) + defer fd.Close() + } + } + + // set graceful restart env flag + envs := []string{} + for _, value := range os.Environ() { + if !strings.HasPrefix(value, "EUDORE_DAEMON_") { + envs = append(envs, value) + } + } + envs = append(envs, + fmt.Sprintf("%s=%d", eudore.EnvEudoreDaemonEnable, 1), + fmt.Sprintf("%s=%d", eudore.EnvEudoreDaemonRestartID, os.Getpid()), + fmt.Sprintf("%s=%s", eudore.EnvEudoreDaemonListeners, strings.Join(addrs, " ,")), + ) + + process, err := os.StartProcess(path, os.Args, &os.ProcAttr{ + Dir: dir, + Env: envs, + Files: append([]*os.File{os.Stdin, os.Stdout, os.Stderr}, files...), + }) + if err == nil { + eudore.NewLoggerWithContext(ctx).Infof("eudore start new process %d", process.Pid) + } + return err +} diff --git a/daemon/signal.go b/daemon/signal.go new file mode 100644 index 0000000..9ae0c20 --- /dev/null +++ b/daemon/signal.go @@ -0,0 +1,70 @@ +package daemon + +import ( + "context" + "os" + "os/signal" + "sync" + + "github.com/eudore/eudore" +) + +// Signal handle func. +type SignalFunc func(context.Context) error + +type Signal struct { + sync.Mutex + Chan chan os.Signal + Funcs map[os.Signal][]SignalFunc +} + +func (sig *Signal) Run(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case s := <-sig.Chan: + log := eudore.NewLoggerWithContext(ctx) + log.Infof("eudore accept signal: %s", s) + err := sig.Handle(ctx, s) + if err != nil { + log.Errorf("eudore handle signal %s error: %v", s, err) + } + } + } +} + +func (sig *Signal) Register(s os.Signal, fn SignalFunc) { + sig.Lock() + defer sig.Unlock() + if fn == nil { + delete(sig.Funcs, s) + } else { + sig.Funcs[s] = append(sig.Funcs[s], fn) + } + if len(sig.Funcs[s]) <= 1 { + sig.Notify() + } +} + +func (sig *Signal) Handle(ctx context.Context, s os.Signal) error { + sig.Lock() + defer sig.Unlock() + for _, fn := range sig.Funcs[s] { + if err := fn(ctx); err != nil { + return err + } + } + return nil +} + +func (sig *Signal) Notify() { + sigs := make([]os.Signal, 0, 4) + for key := range sig.Funcs { + sigs = append(sigs, key) + } + signal.Stop(sig.Chan) + if sigs != nil { + signal.Notify(sig.Chan, sigs...) + } +} diff --git a/database.go b/database.go index 1b951e5..28fab69 100644 --- a/database.go +++ b/database.go @@ -5,7 +5,7 @@ import ( "database/sql" ) -// Database 定义数据库操作方法 +// Database 定义数据库操作方法。 type Database interface { AutoMigrate(interface{}) error Metadata(interface{}) interface{} @@ -25,11 +25,34 @@ type DatabaseStmt interface { type DatabaseBuilder interface { Context() context.Context DriverName() string + Metadata(interface{}) interface{} WriteStmts(...interface{}) Result() (string, []interface{}, error) } -// NewDatabaseStd 方法创建一个空Database。 -func NewDatabaseStd(config interface{}) Database { +// NewDatabase 方法创建一个空Database。 +func NewDatabase(interface{}) Database { return nil } + +type stmtContextRuntime struct { + Context Context + Stmt DatabaseStmt +} + +func NewDatabaseRuntime(ctx Context, stmt DatabaseStmt) DatabaseStmt { + return stmtContextRuntime{ctx, stmt} +} + +var conetxtIDKeys = [...]string{HeaderXTraceID, HeaderXRequestID} + +func (stmt stmtContextRuntime) Build(builder DatabaseBuilder) { + h := stmt.Context.Response().Header() + for _, key := range conetxtIDKeys { + id := h.Get(key) + if id != "" { + builder.WriteStmts("-- "+id+"\r\n", stmt.Stmt) + return + } + } +} diff --git a/funccreator.go b/funccreator.go new file mode 100644 index 0000000..54508ab --- /dev/null +++ b/funccreator.go @@ -0,0 +1,653 @@ +package eudore + +import ( + "context" + "fmt" + "reflect" + "sort" + "strings" + "sync" +) + +const ( + FuncCreateInvalid FuncCreateKind = iota + FuncCreateString + FuncCreateInt + FuncCreateUint + FuncCreateFloat + FuncCreateBool + FuncCreateAny + FuncCreateSetString + FuncCreateSetInt + FuncCreateSetUint + FuncCreateSetFloat + FuncCreateSetBool + FuncCreateSetAny + FuncCreateNumber = FuncCreateAny +) + +// FuncCreateKind 定义FuncCreator可以创建的函数类型。 +type FuncCreateKind uint8 + +// FuncCreator 定义校验函数构造器,默认由RouterStd、validate、filter使用。 +type FuncCreator interface { + RegisterFunc(string, ...any) error + CreateFunc(FuncCreateKind, string) (any, error) + List() []string +} + +type funcTypeParams interface { + string | int | uint | float64 | bool | any +} + +type funcCreatorBase struct { + String typeCreator[func(string) bool] + Int typeCreator[func(int) bool] + Uint typeCreator[func(uint) bool] + Float typeCreator[func(float64) bool] + Bool typeCreator[func(bool) bool] + Any typeCreator[func(any) bool] + SetString typeCreator[func(string) string] + SetInt typeCreator[func(int) int] + SetUint typeCreator[func(uint) uint] + SetFloat typeCreator[func(float64) float64] + SetAny typeCreator[func(any) any] + Errors map[string]string +} + +type MetadataFuncCreator struct { + Health bool `alias:"health" json:"health" xml:"health" yaml:"health"` + Name string `alias:"name" json:"name" xml:"name" yaml:"name"` + Funcs []string `alias:"funcs" json:"funcs" xml:"funcs" yaml:"funcs"` + Exprs []string `alias:"exprs,omitempty" json:"exprs,omitempty" xml:"exprs,omitempty" yaml:"exprs,omitempty"` + Errors []string `alias:"errors,omitempty" json:"errors,omitempty" xml:"errors,omitempty" yaml:"errors,omitempty"` +} + +// FuncRunner 定义创建函数类型和地址信息。 +type FuncRunner struct { + Kind FuncCreateKind + Func any +} + +// NewFuncCreator 函数创建默认FuncCreator并加载默认规则。 +func NewFuncCreator() FuncCreator { + fc := &funcCreatorBase{ + String: newTypeCreator[func(string) bool](), + Int: newTypeCreator[func(int) bool](), + Uint: newTypeCreator[func(uint) bool](), + Float: newTypeCreator[func(float64) bool](), + Bool: newTypeCreator[func(bool) bool](), + Any: newTypeCreator[func(any) bool](), + SetString: newTypeCreator[func(string) string](), + SetInt: newTypeCreator[func(int) int](), + SetUint: newTypeCreator[func(uint) uint](), + SetFloat: newTypeCreator[func(float64) float64](), + SetAny: newTypeCreator[func(any) any](), + Errors: make(map[string]string), + } + loadDefaultFuncDefine(fc) + return fc +} + +// NewFuncCreatorExpr 函数创建一个支持AND OR NOT关系表达式解析的FuncCreator。 +func NewFuncCreatorExpr() FuncCreator { + return &funcCreatorExpr{ + data: NewFuncCreator().(*funcCreatorBase), + parser: newFcExprParser(), + } +} + +// NewFuncCreatorWithContext 函数从环境上下文创建FuncCreator。 +func NewFuncCreatorWithContext(ctx context.Context) FuncCreator { + fc, ok := ctx.Value(ContextKeyFuncCreator).(FuncCreator) + if ok { + return fc + } + return DefaultFuncCreator +} + +// RegisterFunc 函数给一个名称注册多个类型的的ValidateFunc或ValidateNewFunc。 +// +//nolint:cyclop,gocyclo +func (fc *funcCreatorBase) RegisterFunc(name string, funcs ...any) error { + for i := range funcs { + switch fn := funcs[i].(type) { + case func(string) bool: + fc.String.Register(name, fn) + case func(int) bool: + fc.Int.Register(name, fn) + case func(uint) bool: + fc.Uint.Register(name, fn) + case func(float64) bool: + fc.Float.Register(name, fn) + case func(bool) bool: + fc.Bool.Register(name, fn) + case func(any) bool: + fc.Any.Register(name, fn) + case func(string) (func(string) bool, error): + fc.String.RegisterNew(name, fn) + case func(string) (func(uint) bool, error): + fc.Uint.RegisterNew(name, fn) + case func(string) (func(int) bool, error): + fc.Int.RegisterNew(name, fn) + case func(string) (func(float64) bool, error): + fc.Float.RegisterNew(name, fn) + case func(string) (func(bool) bool, error): + fc.Bool.RegisterNew(name, fn) + case func(string) (func(any) bool, error): + fc.Any.RegisterNew(name, fn) + case func(string) string: + fc.SetString.Register(name, fn) + case func(int) int: + fc.SetInt.Register(name, fn) + case func(uint) uint: + fc.SetUint.Register(name, fn) + case func(float64) float64: + fc.SetFloat.Register(name, fn) + case func(any) any: + fc.SetAny.Register(name, fn) + case func(string) (func(string) string, error): + fc.SetString.RegisterNew(name, fn) + case func(string) (func(uint) uint, error): + fc.SetUint.RegisterNew(name, fn) + case func(string) (func(int) int, error): + fc.SetInt.RegisterNew(name, fn) + case func(string) (func(float64) float64, error): + fc.SetFloat.RegisterNew(name, fn) + case func(string) (func(any) any, error): + fc.SetAny.RegisterNew(name, fn) + default: + return fc.appendError(name, fmt.Errorf(ErrFormatFuncCreatorRegisterInvalidType, name, fn)) + } + } + return nil +} + +// CreateFunc 方法感觉类型和名称创建校验函数。 +// +// 不支持动态创建具有NOT AND OR关系表达式函数,闭包影响性能。 +func (fc *funcCreatorBase) CreateFunc(kind FuncCreateKind, name string) (fn any, err error) { + switch kind { + case FuncCreateString: + fn, err = fc.String.Create(name) + case FuncCreateInt: + fn, err = fc.Int.Create(name) + case FuncCreateUint: + fn, err = fc.Uint.Create(name) + case FuncCreateFloat: + fn, err = fc.Float.Create(name) + case FuncCreateBool, FuncCreateSetBool: + fn, err = fc.Bool.Create(name) + case FuncCreateAny: + fn, err = fc.Any.Create(name) + case FuncCreateSetString: + fn, err = fc.SetString.Create(name) + case FuncCreateSetInt: + fn, err = fc.SetInt.Create(name) + case FuncCreateSetUint: + fn, err = fc.SetUint.Create(name) + case FuncCreateSetFloat: + fn, err = fc.SetFloat.Create(name) + case FuncCreateSetAny: + fn, err = fc.SetAny.Create(name) + default: + err = fmt.Errorf("invalid func kind %d", kind) + } + if err != nil { + return nil, fc.appendError(kind.String()+name, fmt.Errorf("funcCreator create kind %s func %s err: %w", kind, name, err)) + } + return fn, nil +} + +func (fc *funcCreatorBase) appendError(key string, err error) error { + fc.Bool.Lock() + fc.Errors[key] = err.Error() + fc.Bool.Unlock() + return err +} + +func (fc *funcCreatorBase) List() []string { + names := make([]string, 0, 128) + names = fc.String.List(names) + names = fc.Int.List(names) + names = fc.Uint.List(names) + names = fc.Float.List(names) + names = fc.Bool.List(names) + names = fc.Any.List(names) + names = fc.SetString.List(names) + names = fc.SetInt.List(names) + names = fc.SetUint.List(names) + names = fc.SetFloat.List(names) + names = fc.SetAny.List(names) + sort.Strings(names) + return names +} + +func (fc *funcCreatorBase) Metadata() any { + errs := make([]string, 0, len(fc.Errors)) + fc.Bool.RLock() + defer fc.Bool.RUnlock() + for _, v := range fc.Errors { + errs = append(errs, v) + } + return MetadataFuncCreator{ + Health: len(errs) == 0, + Name: "eudore.funcCreatorBase", + Funcs: fc.List(), + Errors: errs, + } +} + +type typeCreator[T any] struct { + sync.RWMutex + Values map[string]T + Constructor map[string]func(string) (T, error) +} + +func newTypeCreator[T any]() typeCreator[T] { + return typeCreator[T]{ + Values: make(map[string]T), + Constructor: make(map[string]func(string) (T, error)), + } +} + +func (tc *typeCreator[T]) Register(name string, fn T) { + tc.Lock() + tc.Values[name] = fn + tc.Unlock() +} + +func (tc *typeCreator[T]) RegisterNew(name string, fn func(string) (T, error)) { + tc.Lock() + tc.Constructor[name] = fn + tc.Unlock() +} + +func (tc *typeCreator[T]) Get(fullname string) (T, bool) { + tc.RLock() + fn, ok := tc.Values[fullname] + tc.RUnlock() + return fn, ok +} + +func (tc *typeCreator[T]) Create(fullname string) (T, error) { + tc.RLock() + fn, ok := tc.Values[fullname] + tc.RUnlock() + if ok { + return fn, nil + } + + name, arg := getFuncNameArg(fullname) + if arg != "" { + tc.RLock() + fnnews, ok := tc.Constructor[name] + tc.RUnlock() + if ok { + fn, err := fnnews(arg) + if err == nil { + tc.Register(fullname, fn) + } + return fn, err + } + } + return fn, ErrFuncCreatorNotFunc +} + +func (tc *typeCreator[T]) List(names []string) []string { + tc.RLock() + defer tc.RUnlock() + for key, fn := range tc.Values { + names = append(names, fmt.Sprintf("%s: %T", key, fn)) + } + for key, fn := range tc.Constructor { + names = append(names, fmt.Sprintf("%s: %T", key, fn)) + } + return names +} + +func getFuncNameArg(name string) (string, string) { + for i, b := range name { + // ! [0-9A-Za-z] + if b < 0x30 || (0x39 < b && b < 0x41) || (0x5A < b && b < 0x61) || 0x7A < b { + return name[:i], name[i:] + } + } + return name, "" +} + +type funcCreatorExpr struct { + data *funcCreatorBase + parser *fcExprParser +} + +func (fc *funcCreatorExpr) RegisterFunc(name string, funcs ...any) error { + return fc.data.RegisterFunc(name, funcs...) +} + +func (fc *funcCreatorExpr) CreateFunc(kind FuncCreateKind, name string) (any, error) { + if kind < FuncCreateSetString && (strings.Contains(name, "NOT") || + strings.Contains(name, "AND") || strings.Contains(name, "OR")) { + var fn any + var err error + switch kind { + case FuncCreateString: + fn, err = createFunc(&fc.data.String, name, fc.parser.parse) + case FuncCreateInt: + fn, err = createFunc(&fc.data.Int, name, fc.parser.parse) + case FuncCreateUint: + fn, err = createFunc(&fc.data.Uint, name, fc.parser.parse) + case FuncCreateFloat: + fn, err = createFunc(&fc.data.Float, name, fc.parser.parse) + case FuncCreateBool, FuncCreateSetBool: + fn, err = createFunc(&fc.data.Bool, name, fc.parser.parse) + case FuncCreateAny: + fn, err = createFunc(&fc.data.Any, name, fc.parser.parse) + } + if err != nil { + return nil, fc.data.appendError(kind.String()+name, err) + } + return fn, nil + } + return fc.data.CreateFunc(kind, name) +} + +func createFunc[T funcTypeParams](tc *typeCreator[func(T) bool], name string, parser fcExprFunc) (func(T) bool, error) { + fn, ok := tc.Get(name) + if ok { + return fn, nil + } + + expr, s := parser(name) + if s != "" { + return nil, fmt.Errorf("funcCreatorExpr not parse: %s, pos in %d", name, len(name)-len(s)) + } + + fn, err := createExpr(tc, expr) + if _, isstr := expr.(string); err == nil && !isstr { + tc.Register(name, fn) + } + return fn, err +} + +func createExpr[T funcTypeParams](tc *typeCreator[func(T) bool], expr any) (func(T) bool, error) { + switch e := expr.(type) { + case fcExprNot: + fn, err := createExpr(tc, e.Expr) + if err != nil { + return nil, err + } + return func(t T) bool { + return !fn(t) + }, nil + case fcExprAnd: + fns := make([]func(T) bool, len(e.Exprs)) + for i := range e.Exprs { + fn, err := createExpr(tc, e.Exprs[i]) + if err != nil { + return nil, err + } + fns[i] = fn + } + return func(t T) bool { + for i := range fns { + if !fns[i](t) { + return false + } + } + return true + }, nil + case fcExprOr: + fns := make([]func(T) bool, len(e.Exprs)) + for i := range e.Exprs { + fn, err := createExpr(tc, e.Exprs[i]) + if err != nil { + return nil, err + } + fns[i] = fn + } + return func(t T) bool { + for i := range fns { + if fns[i](t) { + return true + } + } + return false + }, nil + default: + return tc.Create(expr.(string)) + } +} + +func (fc *funcCreatorExpr) List() []string { + return fc.data.List() +} + +func (fc *funcCreatorExpr) Metadata() any { + meta := fc.data.Metadata().(MetadataFuncCreator) + meta.Name = "eudore.funcCreatorExpr" + var funcs, exprs []string + for _, f := range meta.Funcs { + if strings.Contains(f, "NOT") || strings.Contains(f, "AND") || strings.Contains(f, "OR") { + exprs = append(exprs, f) + } else { + funcs = append(funcs, f) + } + } + meta.Funcs = funcs + meta.Exprs = exprs + return meta +} + +type ( + fcExprFunc func(string) (any, string) + fcExprParser struct { + Parsers [][]fcExprFunc + Handler []func([]any) any + parse fcExprFunc + } + fcExprNot struct{ Expr any } + fcExprAnd struct{ Exprs []any } + fcExprOr struct{ Exprs []any } +) + +func newFcExprParser() *fcExprParser { + p := &fcExprParser{Handler: []func(data []any) any{ + func(data []any) any { return fcExprOr{fcExprData(data, false)} }, + func(data []any) any { return fcExprAnd{fcExprData(data, true)} }, + func(data []any) any { return fcExprNot{data[1]} }, + func(data []any) any { return data[1] }, + func(data []any) any { return data[0] }, + }} + p0, p1, p2 := p.p(0), p.p(1), p.p(2) + p.parse = p0 + p.Parsers = [][]fcExprFunc{ + {p1, fcExprMatch("OR"), p0}, + {p2, fcExprMatch("AND"), p1}, + {fcExprMatch("NOT"), p2}, + {fcExprMatch("("), p0, fcExprMatch(")")}, + {fcExprVal}, + } + return p +} + +func (p *fcExprParser) p(start int) fcExprFunc { + return func(s string) (any, string) { + for i, fns := range p.Parsers[start:] { + var val any + var vals []any + str := s + for _, fn := range fns { + val, str = fn(strings.TrimSpace(str)) + if val == nil { + break + } + vals = append(vals, val) + } + if len(vals) == len(fns) { + return p.Handler[i+start](vals), str + } + } + return nil, s + } +} + +func fcExprVal(s string) (any, string) { + l := len(s) + for _, sub := range []string{"NOT", "AND", "OR", ")"} { + pos := strings.Index(s[:l], sub) + if pos != -1 { + l = pos + } + } + s1, s2 := strings.TrimSpace(s[:l]), s[l:] + if s1 == "" { + return nil, s + } + return s1, s2 +} + +var exprEnd = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1, '(': 1} + +func fcExprMatch(str string) fcExprFunc { + return func(s string) (any, string) { + if strings.HasPrefix(s, str) { + if len(str) == 1 || (str != s && exprEnd[s[len(str)]] == 1) { + return "", s[len(str):] + } + } + return nil, s + } +} + +func fcExprData(data []any, and bool) []any { + d := make([]any, 0, len(data)) + for i := range data { + switch val := data[i].(type) { + case string: + if val == "" { + continue + } + case fcExprAnd: + if and { + d = append(d, val.Exprs...) + continue + } + case fcExprOr: + if !and { + d = append(d, val.Exprs...) + continue + } + } + d = append(d, data[i]) + } + return d +} + +var defaultFuncCreateKindStrings = [...]string{ + "invalid", "string", "int", "uint", "float", "bool", "any", + "setstring", "setint", "setuint", "setfloat", "setbool", "setany", +} + +// NewFuncCreateKind 函数解析FuncCreateKind字符串。 +func NewFuncCreateKind(s string) FuncCreateKind { + s = strings.ToLower(s) + for i, str := range defaultFuncCreateKindStrings { + if s == str { + return FuncCreateKind(i) + } + } + return FuncCreateInvalid +} + +// NewFuncCreateKindWithType 函数类型创建FuncCreateKind。 +func NewFuncCreateKindWithType(t reflect.Type) FuncCreateKind { + if t == nil { + return FuncCreateInvalid + } + for { + switch t.Kind() { + case reflect.Ptr: + t = t.Elem() + case reflect.Slice, reflect.Array: + t = t.Elem() + case reflect.String: + return FuncCreateString + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return FuncCreateInt + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return FuncCreateUint + case reflect.Float32, reflect.Float64: + return FuncCreateFloat + case reflect.Bool: + return FuncCreateBool + case reflect.Struct, reflect.Map, reflect.Interface: + return FuncCreateAny + default: + return FuncCreateInvalid + } + } +} + +func (kind FuncCreateKind) String() string { + return defaultFuncCreateKindStrings[kind] +} + +// RunPtr executes any function, dereferences the Ptr type, +// and executes each value of Slice and Array. +// +// RunPtr 执行任意函数,对Ptr类型会解除引用,对Slice和Array每一个值进行执行。 +func (fn *FuncRunner) RunPtr(v reflect.Value) bool { + switch v.Kind() { + case reflect.Ptr: + if v.IsNil() { + return fn.RunPtr(reflect.Zero(v.Type().Elem())) + } + return fn.RunPtr(v.Elem()) + case reflect.Slice, reflect.Array: + for i := 0; i < v.Len(); i++ { + if !fn.RunPtr(v.Index(i)) { + return false + } + } + return true + default: + return fn.Run(v) + } +} + +// Run 执行任意函数。 +func (fn *FuncRunner) Run(v reflect.Value) bool { + switch fn.Kind { + case FuncCreateString: + return fn.Func.(func(string) bool)(v.String()) + case FuncCreateInt: + return fn.Func.(func(int) bool)(int(v.Int())) + case FuncCreateUint: + return fn.Func.(func(uint) bool)(uint(v.Uint())) + case FuncCreateFloat: + return fn.Func.(func(float64) bool)(v.Float()) + case FuncCreateBool: + return fn.Func.(func(bool) bool)(v.Bool()) + case FuncCreateAny: + return fn.Func.(func(any) bool)(v.Interface()) + case FuncCreateSetString: + v.SetString(fn.Func.(func(string) string)(v.String())) + case FuncCreateSetInt: + v.SetInt(int64((fn.Func.(func(int) int)(int(v.Int()))))) + case FuncCreateSetUint: + v.SetUint(uint64(fn.Func.(func(uint) uint)(uint(v.Uint())))) + case FuncCreateSetFloat: + v.SetFloat(fn.Func.(func(float64) float64)(v.Float())) + case FuncCreateSetBool: + v.SetBool(fn.Func.(func(bool) bool)(v.Bool())) + case FuncCreateSetAny: + val := fn.Func.(func(any) any)(v.Interface()) + if val != nil { + v.Set(reflect.ValueOf(val)) + } else { + v.Set(reflect.Zero(v.Type())) + } + } + return true +} diff --git a/funcdefine.go b/funcdefine.go new file mode 100644 index 0000000..69b0ca3 --- /dev/null +++ b/funcdefine.go @@ -0,0 +1,562 @@ +package eudore + +import ( + "crypto/md5" + "encoding/hex" + "fmt" + "reflect" + "regexp" + "strconv" + "strings" + "time" + "unicode" +) + +func loadDefaultFuncDefine(fc FuncCreator) { + f := fc.RegisterFunc + // func(T) bool + _ = f("zero", fbZero[string], fbZero[int], fbZero[uint], fbZero[float64], fbZero[bool], fcAnyZero) + _ = f("nozero", fbNozero[string], fbNozero[int], fbNozero[uint], fbNozero[float64], fbNozero[bool], fcAnyNozero) + _ = f("min", fbMin[int], fbMin[uint], fbMin[float64], fbStringMin) // min=1 + _ = f("max", fbMax[int], fbMax[uint], fbMax[float64], fbStringMax) // max=1 + _ = f("equal", fbEqual[string], fbEqual[int], fbEqual[uint]) // equal=string equal!=string + _ = f("enum", fbEnum[int], fbEnum[uint], fbEnum[string]) // enum=1,2,3 enum!=1,2,3 + _ = f("len", fbStringLen, fbAnyLen) // len=3 len<3 len>3 + _ = f("num", fbStringNum) + _ = f("integer", fbStringInteger) + _ = f("domain", fbStringDomain) + _ = f("mail", fbStringMail) + _ = f("phone", fbStringPhone) + _ = f("regexp", fbStringRegexp) + _ = f("patten", fbStringpPatten) + _ = f("prefix", fbStringFuncBool(strings.HasPrefix)) // prefix=string prefix!=string + _ = f("suffix", fbStringFuncBool(strings.HasSuffix)) // suffix=string + _ = f("contains", fbStringFuncBool(strings.Contains)) // contains=string + _ = f("fold", fbStringFuncBool(strings.EqualFold)) // fold=string + _ = f("count", fbStringFuncInt(strings.Count)) // count=number,string countnumber,string + _ = f("compare", fbStringFuncInt(strings.Compare)) // compare=number,string + _ = f("index", fbStringFuncInt(strings.Index)) // index=number,string + _ = f("lastindex", fbStringFuncInt(strings.LastIndex)) // lastindex=number,string + _ = f("after", fbTimeAfter) + _ = f("before", fbTimeBefore) + // func(T) T + _ = f("default", fsDefault[string], fsDefault[int], fsDefault[uint], fsDefault[float64], fsDefault[bool], fsAnyDefault) + _ = f("value", fsValue[string], fsValue[int], fsValue[uint], fsValue[float64], fsValue[bool], fsTimeValue) + _ = f("add", fsAdd[int], fsAdd[uint], fsAdd[float64], fsTimeAdd) + _ = f("now", fsStringNow, fsTimeNow) + _ = f("len", fsStringLen) + _ = f("md5", fsStringMd5) + _ = f("tolower", strings.ToLower) // tolower + _ = f("toupper", strings.ToUpper) // toupper + _ = f("totitle", strings.Title) //nolint:staticcheck + _ = f("replace", fsStringReplace) // replace=old,new replace=-1,old,new + _ = f("trimspace", strings.TrimSpace) // trimspace + _ = f("trim", fsStringFuncString(strings.Trim)) // trim=string + _ = f("trimprefix", fsStringFuncString(strings.TrimPrefix)) // trim=trimprefix + _ = f("trimsuffix", fsStringFuncString(strings.TrimSuffix)) // trim=trimsuffix + _ = f("hide", fsStringHide) + _ = f("hidesurname", fsStringHideSurname) + _ = f("hidename", fsStringHideName) + _ = f("hidemail", fsStringHideMail) + _ = f("hidephone", fsStringHidePhone) + // alias + _ = f("must", fbNozero[string], fbNozero[int], fbNozero[uint], fbNozero[float64], fbNozero[bool], fcAnyNozero) + _ = f("eq", fbEqual[string], fbEqual[int], fbEqual[uint]) +} + +func trimFuncOperate(str string) string { + for _, r := range []byte{'!', '<', '>', '=', ':'} { + if len(str) > 0 && str[0] == r { + str = str[1:] + } + } + return str +} + +func fbZero[T int | uint | float64 | string | bool](i T) bool { + var t T + return i == t +} + +func fcAnyZero(i any) bool { + t, ok := i.(time.Time) + if ok { + return t.IsZero() + } + return i == nil || reflect.ValueOf(i).IsZero() +} + +func fbNozero[T int | uint | float64 | string | bool](i T) bool { + var t T + return i != t +} + +// fcAnyNozero 函数验证一个对象是否为零值,使用reflect.Value.IsZero函数实现。 +func fcAnyNozero(i any) bool { + return !fcAnyZero(i) +} + +// fcgMin 函数生成一个验证value最小值的验证函数。 +func fbMin[T int | uint | float64](s string) (func(T) bool, error) { + val, err := GetAnyByStringWithError[T](trimFuncOperate(s)) + if err != nil { + return nil, err + } + return func(num T) bool { + return num >= val + }, nil +} + +// fbMax 函数生成一个验证value最大值的验证函数。 +func fbMax[T int | uint | float64](s string) (func(T) bool, error) { + val, err := GetAnyByStringWithError[T](trimFuncOperate(s)) + if err != nil { + return nil, err + } + return func(num T) bool { + return num <= val + }, nil +} + +// fbStringMin 函数生成一个验证string最小值的验证函数。 +func fbStringMin(s string) (func(string) bool, error) { + min, err := strconv.ParseInt(trimFuncOperate(s), 10, 32) + if err != nil { + return nil, err + } + intmin := int(min) + return func(arg string) bool { + num, err := strconv.Atoi(arg) + if err != nil { + return false + } + return num >= intmin + }, nil +} + +// fbStringMax 函数生成一个验证string最大值的验证函数。 +func fbStringMax(s string) (func(string) bool, error) { + max, err := strconv.ParseInt(trimFuncOperate(s), 10, 32) + if err != nil { + return nil, err + } + intmax := int(max) + return func(arg string) bool { + num, err := strconv.Atoi(arg) + if err != nil { + return false + } + return num <= intmax + }, nil +} + +func fbEqual[T string | int | uint](s string) (func(T) bool, error) { + b := !strings.HasPrefix(s, "!=") + val, err := GetAnyByStringWithError[T](trimFuncOperate(s)) + if err != nil { + return nil, err + } + return func(arg T) bool { + return val == arg == b + }, nil +} + +func fbEnum[T string | int | uint](s string) (func(arg T) bool, error) { + b := !strings.HasPrefix(s, "!=") + strs := strings.Split(trimFuncOperate(s), ",") + values := make([]T, len(strs)) + for i := range strs { + val, err := GetAnyByStringWithError[T](strs[i]) + if err != nil { + return nil, err + } + values[i] = val + } + + if len(strs) < 9 { + return func(arg T) bool { + for _, val := range values { + if val == arg { + return b + } + } + return !b + }, nil + } + + macths := make(map[T]struct{}, len(strs)) + for i := range values { + macths[values[i]] = struct{}{} + } + return func(arg T) bool { + _, has := macths[arg] + return has == b + }, nil +} + +func integerCompare(op byte, a, b int) bool { + switch op { + case '>': + return a > b + case '<': + return a < b + case '!': + return a != b + default: + return a == b + } +} + +// fbStringLen 函数生一个验证字符串长度'>','<','='指定长度的验证函数。 +func fbStringLen(s string) (func(s string) bool, error) { + length, err := strconv.ParseInt(trimFuncOperate(s), 10, 32) + if err != nil { + return nil, err + } + + f, l := s[0], int(length) + return func(s string) bool { + return integerCompare(f, len(s), l) + }, nil +} + +// fbAnyLen 函数生一个验证字符串长度'>','<','='指定长度的验证函数。 +func fbAnyLen(s string) (func(i any) bool, error) { + length, err := strconv.ParseInt(trimFuncOperate(s), 10, 32) + if err != nil { + return nil, err + } + f, l := s[0], int(length) + return func(i any) bool { + v := reflect.Indirect(reflect.ValueOf(i)) + switch v.Kind() { + case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String: + return integerCompare(f, v.Len(), l) + default: + return false + } + }, nil +} + +// fbStringNum 函数验证一个字符串是否为float64。 +func fbStringNum(s string) bool { + _, err := strconv.ParseFloat(s, 64) + return err == nil +} + +func fbStringInteger(s string) bool { + for _, b := range s { + if b < '0' || b > '9' { + return false + } + } + return true +} + +func fbStringDomain(s string) bool { + pos := strings.LastIndexByte(s, '.') + first := strings.IndexByte(s, '.') + return first > 0 && pos > 0 && pos != len(s)-1 +} + +func fbStringMail(s string) bool { + pos := strings.IndexByte(s, '@') + if pos > 0 && pos != len(s)-1 { + return fbStringDomain(s[pos+1:]) + } + return false +} + +func fbStringPhone(s string) bool { + s = strings.Replace(s, " ", "", 5) + // 中国大陆移动电话 运营商/归属地/客户号码 1xx/xxxx/xxxx + if len(s) == 11 && s[0] == '1' && fbStringInteger(s) { + return true + } + // 国际电话号码格式 国际冠码/国际电话区号/电话号码 国际冠码00或+ + if len(s) > 9 && len(s) < 21 && (s[0] == '+' || strings.HasPrefix(s, "00")) && + fbStringInteger(strings.ReplaceAll(s[1:], "-", "")) { + return true + } + // 中国大陆固定电话 长途冠码/省市区号/电话号码 长途冠码0 省市区号2位或3位 电话号码7位或8位 + if len(s) > 9 && len(s) < 13 && s[0] == '0' { + pos := strings.IndexByte(s, '-') + if pos == 2 || pos == 3 { + return fbStringInteger(s[:pos]) && fbStringInteger(s[pos+1:]) + } + } + return false +} + +// fbStringpPatten 模式匹配对象,允许使用带'*'的模式。 +func fbStringpPatten(str string) (func(string) bool, error) { + b := !strings.HasPrefix(str, "!=") + patten := trimFuncOperate(str) + return func(obj string) bool { + parts := strings.Split(patten, "*") + if len(parts) < 2 { + return patten == obj == b + } + if !strings.HasPrefix(obj, parts[0]) { + return !b + } + for _, i := range parts { + if i == "" { + continue + } + pos := strings.Index(obj, i) + if pos == -1 { + return !b + } + obj = obj[pos+len(i):] + } + return b + }, nil +} + +// fbStringRegexp 函数生成一个正则检测字符串的验证函数。 +func fbStringRegexp(s string) (func(arg string) bool, error) { + b := !strings.HasPrefix(s, "!=") + re, err := regexp.Compile(trimFuncOperate(s)) + if err != nil { + return nil, err + } + // 返回正则匹配校验函数 + return func(arg string) bool { + return re.MatchString(arg) == b + }, nil +} + +func fbStringFuncBool(fn func(string, string) bool) func(string) (func(string) bool, error) { + return func(s string) (func(string) bool, error) { + b := !strings.HasPrefix(s, "!=") + s = trimFuncOperate(s) + return func(str string) bool { + return fn(str, s) == b + }, nil + } +} + +func fbStringFuncInt(fn func(string, string) int) func(string) (func(string) bool, error) { + return func(s string) (func(string) bool, error) { + num, str, ok := strings.Cut(trimFuncOperate(s), ",") + if !ok || num == "" || str == "" { + return nil, fmt.Errorf("funcCreator setstring format must 'name=num,string', current: %s", s) + } + + n, err := GetAnyByStringWithError[int](num) + if err != nil { + return nil, err + } + + f := s[0] + return func(s string) bool { + return integerCompare(f, fn(s, str), n) + }, nil + } +} + +func fbTimeAfter(str string) (func(any) bool, error) { + t, err := GetAnyByStringWithError[time.Time](trimFuncOperate(str)) + if err != nil { + return nil, err + } + return func(i any) bool { + t2, ok := i.(time.Time) + if ok { + return t2.After(t) + } + return false + }, nil +} + +func fbTimeBefore(str string) (func(any) bool, error) { + t, err := GetAnyByStringWithError[time.Time](trimFuncOperate(str)) + if err != nil { + return nil, err + } + return func(i any) bool { + t2, ok := i.(time.Time) + if ok { + return t2.Before(t) + } + return false + }, nil +} + +func fsDefault[T string | int | uint | float64 | bool](T) T { + var t T + return t +} + +func fsAnyDefault(i any) any { + _, ok := i.(time.Time) + if ok { + return time.Time{} + } + return nil +} + +func fsValue[T string | int | uint | float64 | bool](s string) (func(T) T, error) { + val, err := GetAnyByStringWithError[T](trimFuncOperate(s)) + if err != nil { + return nil, err + } + return func(arg T) T { + return val + }, nil +} + +func fsTimeValue(s string) (func(i any) any, error) { + t, err := GetAnyByStringWithError[time.Time](trimFuncOperate(s)) + if err != nil { + return nil, err + } + return func(i any) any { + return t + }, nil +} + +func fsAdd[T string | int | uint | float64](s string) (func(T) T, error) { + val, err := GetAnyByStringWithError[T](trimFuncOperate(s)) + if err != nil { + return nil, err + } + return func(arg T) T { + return arg + val + }, nil +} + +func fsTimeAdd(s string) (func(i any) any, error) { + d, err := time.ParseDuration(trimFuncOperate(s)) + if err != nil { + return nil, err + } + return func(i any) any { + t, ok := i.(time.Time) + if ok { + return t.Add(d) + } + return i + }, nil +} + +func fsStringNow(str string) (func(string) string, error) { + f := trimFuncOperate(str) + if f == time.Now().Format(f) { + return fsValue[string](str) + } + return func(string) string { + return time.Now().Format(f) + }, nil +} + +func fsTimeNow(any) any { + return time.Now() +} + +func fsStringLen(str string) string { + return strconv.Itoa(len(str)) +} + +func fsStringMd5(str string) string { + h := md5.New() + h.Write([]byte(str)) + return hex.EncodeToString(h.Sum(nil)) +} + +func fsStringReplace(s string) (func(string) string, error) { + o, n, _ := strings.Cut(trimFuncOperate(s), ",") + num, err := GetAnyByStringWithError(o, -1) + if err == nil { + o, n, _ = strings.Cut(n, ",") + } + return func(str string) string { + return strings.Replace(str, o, n, num) + }, nil +} + +func fsStringFuncString(fn func(string, string) string) func(string) (func(string) string, error) { + return func(s string) (func(string) string, error) { + s = trimFuncOperate(s) + return func(str string) string { + return fn(str, s) + }, nil + } +} + +func fsStringHide(string) string { + return "***" +} + +func nameHasChinese(str string) bool { + for _, v := range str { + if unicode.Is(unicode.Han, v) { + return true + } + } + return false +} + +func fsStringHideSurname(s string) string { + if len(s) < 3 { + return "****" + } + if nameHasChinese(s) { + return "*" + string([]rune(s)[1:]) + } + _, n, ok := strings.Cut(s, " ") + if ok { + return "**** " + n + } + return "****" + s[len(s)-2:] +} + +func fsStringHideName(s string) string { + if len(s) < 3 { + return "****" + } + if nameHasChinese(s) { + return string([]rune(s)[0]) + "**" + } + sur, _, ok := strings.Cut(s, " ") + if ok { + return sur + " ****" + } + return sur[0:2] + "****" +} + +func fsStringHideMail(s string) string { + name, domain, ok := strings.Cut(s, "@") + if ok { + if len(name) > 8 { + return name[:3] + "****" + "@" + domain + } + if len(name) > 4 { + return name[:2] + "****" + "@" + domain + } + return "****@" + domain + } + return s +} + +func fsStringHidePhone(phone string) string { + s := strings.Replace(phone, " ", "", 5) + // China 中国大陆移动电话 + if len(s) == 11 && s[0] == '1' && fbStringInteger(s) { + return phone[:len(phone)-8] + "****" + phone[len(phone)-4:] + } + // 国际电话号码格式 国际冠码/国际电话区号/电话号码 国际冠码00或+ + if len(s) > 9 && len(s) < 21 && (s[0] == '+' || strings.HasPrefix(s, "00")) && + fbStringInteger(strings.ReplaceAll(s[1:], "-", "")) { + return phone[:len(phone)-8] + "****" + phone[len(phone)-4:] + } + // 中国大陆固定电话 长途冠码/省市区号/电话号码 长途冠码0 省市区号2位或3位 电话号码7位或8位 + if len(s) > 9 && len(s) < 13 && s[0] == '0' { + pos := strings.IndexByte(s, '-') + if pos == 2 || pos == 3 { + return phone[:len(phone)-4] + "****" + } + } + return phone +} diff --git a/go.mod b/go.mod index 2033a90..082be5e 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/eudore/eudore -go 1.9 +go 1.20 \ No newline at end of file diff --git a/handler.go b/handler.go index dcee6ae..b97cd23 100644 --- a/handler.go +++ b/handler.go @@ -1,418 +1,104 @@ package eudore import ( + "errors" "fmt" + iofs "io/fs" + "math" "net/http" - "path/filepath" + "os" + filepath "path" "reflect" "runtime" - "strings" + "sort" + "strconv" "unsafe" ) -// HandlerFunc 是处理一个Context的函数 +// HandlerFunc is a function that processes a Context. +// +// HandlerFunc 是处理一个Context的函数。 type HandlerFunc func(Context) +// HandlerFuncs is a collection of HandlerFunc, representing multiple request processing functions. +// // HandlerFuncs 是HandlerFunc的集合,表示多个请求处理函数。 type HandlerFuncs []HandlerFunc -// HandlerExtender 定义函数扩展处理者的方法。 -// -// HandlerExtender默认拥有Base、Warp、Tree三种实现,具体参数三种对象的文档。 -type HandlerExtender interface { - RegisterHandlerExtend(string, interface{}) error - NewHandlerFuncs(string, interface{}) HandlerFuncs - ListExtendHandlerNames() []string -} - -// handlerExtendBase 定义基础的函数扩展。 -type handlerExtendBase struct { - ExtendNewType []reflect.Type - ExtendNewFunc []reflect.Value - ExtendInterfaceType []reflect.Type - ExtendInterfaceFunc []reflect.Value -} - -// handlerExtendWarp 定义链式函数扩展。 -type handlerExtendWarp struct { - HandlerExtender - LastExtender HandlerExtender -} - -// handlerExtendTree 定义基于路径匹配的函数扩展。 -type handlerExtendTree struct { - HandlerExtender - path string - childs []*handlerExtendTree -} - -type handlerHTTP interface { - HandleHTTP(Context) -} - -var ( - // contextFuncName key类型一定为HandlerFunc类型,保存函数可能正确的名称。 - contextFuncName = make(map[uintptr]string) // 最终名称 - contextSaveName = make(map[uintptr]string) // 函数名称 - contextAliasName = make(map[uintptr][]string) // 对象名称 - fineLineFieldsKeys = []string{"file", "line"} -) - -// init 函数初始化内置扩展的请求上下文处理函数。 -func init() { - // 路由方法扩展 - DefaultHandlerExtend.RegisterHandlerExtend("", NewExtendHandlerHTTP) - DefaultHandlerExtend.RegisterHandlerExtend("", NewExtendHandlerNetHTTP) - DefaultHandlerExtend.RegisterHandlerExtend("", NewExtendFuncNetHTTP1) - DefaultHandlerExtend.RegisterHandlerExtend("", NewExtendFuncNetHTTP2) - DefaultHandlerExtend.RegisterHandlerExtend("", NewExtendFunc) - DefaultHandlerExtend.RegisterHandlerExtend("", NewExtendFuncRender) - DefaultHandlerExtend.RegisterHandlerExtend("", NewExtendFuncError) - DefaultHandlerExtend.RegisterHandlerExtend("", NewExtendFuncRenderError) - DefaultHandlerExtend.RegisterHandlerExtend("", NewExtendFuncContextError) - DefaultHandlerExtend.RegisterHandlerExtend("", NewExtendFuncContextRender) - DefaultHandlerExtend.RegisterHandlerExtend("", NewExtendFuncContextRenderError) - DefaultHandlerExtend.RegisterHandlerExtend("", NewExtendFuncContextInterfaceError) - DefaultHandlerExtend.RegisterHandlerExtend("", NewExtendFuncRPCMap) - DefaultHandlerExtend.RegisterHandlerExtend("", NewExtendHandlerRPC) - DefaultHandlerExtend.RegisterHandlerExtend("", NewExtendFuncString) - DefaultHandlerExtend.RegisterHandlerExtend("", NewExtendHandlerStringer) -} - -// NewHandlerExtendBase method returns a basic function extension processing object. -// -// The NewHandlerExtendBase().RegisterHandlerExtend method registers a conversion function and ignores the path. -// -// The NewHandlerExtendBase().NewHandlerFuncs method implementation creates multiple request handler functions, ignoring paths. -// -// NewHandlerExtendBase 方法返回一个基本的函数扩展处理对象。 -// -// NewHandlerExtendBase().RegisterHandlerExtend 方法实现注册一个转换函数,忽略路径。 -// -// NewHandlerExtendBase().NewHandlerFuncs 方法实现创建多个请求处理函数,忽略路径。 -func NewHandlerExtendBase() HandlerExtender { - return &handlerExtendBase{} -} - -// RegisterHandlerExtend 函数注册一个请求上下文处理转换函数,参数必须是一个函数,该函数的参数必须是一个函数、接口、指针类型之一,返回值必须是返回一个HandlerFunc对象。 -// -// 如果添加多个接口类型转换,注册类型不直接是接口而是实现接口,会按照接口注册顺序依次检测是否实现接口。 -// -// 例如: func(func(...)) HanderFunc, func(http.Handler) HandlerFunc -func (ext *handlerExtendBase) RegisterHandlerExtend(_ string, fn interface{}) error { - iType := reflect.TypeOf(fn) - // RegisterHandlerExtend函数的参数必须是一个函数类型 - if iType.Kind() != reflect.Func { - return ErrRegisterNewHandlerParamNotFunc - } - - // 检查函数参数必须为 func(Type) 或 func(string, Type) ,允许使用的type值定义在DefaultHandlerExtendAllowType。 - if (iType.NumIn() != 1) && (iType.NumIn() != 2 || iType.In(0).Kind() != reflect.String) { - return fmt.Errorf(ErrFormatRegisterHandlerExtendInputParamError, iType.String()) - } - _, ok := DefaultHandlerExtendAllowType[iType.In(iType.NumIn()-1).Kind()] - if !ok { - return fmt.Errorf(ErrFormatRegisterHandlerExtendInputParamError, iType.String()) - } - - // 检查函数返回值必须是HandlerFunc - if iType.NumOut() != 1 || iType.Out(0) != typeHandlerFunc { - return fmt.Errorf(ErrFormatRegisterHandlerExtendOutputParamError, iType.String()) - } - - ext.ExtendNewType = append(ext.ExtendNewType, iType.In(iType.NumIn()-1)) - ext.ExtendNewFunc = append(ext.ExtendNewFunc, reflect.ValueOf(fn)) - if iType.In(iType.NumIn()-1).Kind() == reflect.Interface { - ext.ExtendInterfaceType = append(ext.ExtendInterfaceType, iType.In(iType.NumIn()-1)) - ext.ExtendInterfaceFunc = append(ext.ExtendInterfaceFunc, reflect.ValueOf(fn)) - } - return nil -} - -// NewHandlerFuncs 函数根据参数返回一个HandlerFuncs。 -func (ext *handlerExtendBase) NewHandlerFuncs(path string, i interface{}) HandlerFuncs { - val, ok := i.(reflect.Value) - if !ok { - val = reflect.ValueOf(i) - } - return NewHandlerFuncsFilter(ext.newHandlerFuncs(path, val)) -} - -func (ext *handlerExtendBase) newHandlerFuncs(path string, iValue reflect.Value) HandlerFuncs { - // 基础类型返回 - switch fn := iValue.Interface().(type) { - case func(Context): - SetHandlerFuncName(fn, getHandlerAliasName(iValue)) - return HandlerFuncs{fn} - case HandlerFunc: - SetHandlerFuncName(fn, getHandlerAliasName(iValue)) - return HandlerFuncs{fn} - case []HandlerFunc: - return fn - case HandlerFuncs: - return fn - } - // 尝试转换成HandlerFuncs - fn := ext.newHandlerFunc(path, iValue) - if fn != nil { - return HandlerFuncs{fn} - } - // 解引用数组再转换HandlerFuncs - switch iValue.Type().Kind() { - case reflect.Slice, reflect.Array: - var fns HandlerFuncs - for i := 0; i < iValue.Len(); i++ { - hs := ext.newHandlerFuncs(path, iValue.Index(i)) - if hs != nil { - fns = append(fns, hs...) - } - } - if len(fns) != 0 { - return fns - } - case reflect.Interface, reflect.Ptr: - return ext.newHandlerFuncs(path, iValue.Elem()) - } - return nil -} - -// newHandlerFunc 函数使用一个函数或接口参数转换成请求上下文处理函数。 -// -// 参数必须是一个函数,函数拥有一个参数作为入参,一个HandlerFunc对象作为返回值。 -// -// 先检测对象是否拥有直接注册的类型扩展函数,再检查对象是否实现其中注册的接口类型。 +// HandlerFuncs is a collection of HandlerFunc, representing multiple request processing functions. // -// 允许进行多次注册,只要注册返回值不为空就会返回对应的处理函数。 -func (ext *handlerExtendBase) newHandlerFunc(path string, iValue reflect.Value) HandlerFunc { - iType := iValue.Type() - for i := range ext.ExtendNewType { - if ext.ExtendNewType[i] == iType { - h := ext.createHandlerFunc(path, ext.ExtendNewFunc[i], iValue) - if h != nil { - return h - } - } - } - // 判断是否实现接口类型 - for i, iface := range ext.ExtendInterfaceType { - if iType.Implements(iface) { - h := ext.createHandlerFunc(path, ext.ExtendInterfaceFunc[i], iValue) - if h != nil { - return h - } - } - } - return nil -} - -// createHandlerFunc 函数使用转换函数和对象创建一个HandlerFunc,并保存HandlerFunc的名称和使用的扩展函数名称。 -func (ext *handlerExtendBase) createHandlerFunc(path string, fn, iValue reflect.Value) (h HandlerFunc) { - if fn.Type().NumIn() == 1 { - h = fn.Call([]reflect.Value{iValue})[0].Interface().(HandlerFunc) - } else { - h = fn.Call([]reflect.Value{reflect.ValueOf(path), iValue})[0].Interface().(HandlerFunc) - } - if h == nil { - return nil - } - // 获取扩展名称,eudore包移除包前缀 - extname := runtime.FuncForPC(fn.Pointer()).Name() - if len(extname) > 24 && extname[:25] == "github.com/eudore/eudore." { - extname = extname[25:] - } - // 获取新函数名称,一般来源于函数扩展返回的函数名称。 - hptr := getFuncPointer(reflect.ValueOf(h)) - name := contextSaveName[hptr] - // 使用原值名称 - if name == "" && iValue.Kind() != reflect.Struct { - name = getHandlerAliasName(iValue) - } - // 推断名称 - if name == "" { - iType := iValue.Type() - switch iType.Kind() { - case reflect.Func: - name = runtime.FuncForPC(iValue.Pointer()).Name() - case reflect.Ptr: - iType = iType.Elem() - name = fmt.Sprintf("*%s.%s", iType.PkgPath(), iType.Name()) - case reflect.Struct: - name = fmt.Sprintf("%s.%s", iType.PkgPath(), iType.Name()) - } - } - contextFuncName[hptr] = fmt.Sprintf("%s(%s)", name, extname) - return h -} - -var formarExtendername = "%s(%s)" - -// ListExtendHandlerNames 方法返回全部注册的函数名称。 -func (ext *handlerExtendBase) ListExtendHandlerNames() []string { - names := make([]string, 0, len(ext.ExtendNewFunc)) - for i := range ext.ExtendNewType { - if ext.ExtendNewType[i].Kind() != reflect.Interface { - names = append(names, fmt.Sprintf(formarExtendername, runtime.FuncForPC(ext.ExtendNewFunc[i].Pointer()).Name(), ext.ExtendNewType[i].String())) - } - } - for i, iface := range ext.ExtendInterfaceType { - names = append(names, fmt.Sprintf(formarExtendername, runtime.FuncForPC(ext.ExtendInterfaceFunc[i].Pointer()).Name(), iface.String())) - } - return names +// HandlerEmpty 函数定义一个空的请求上下文处理函数。 +func HandlerEmpty(Context) { + // Do nothing because empty handler does not process entries. } -// NewHandlerExtendWarp function creates a chained HandlerExtender object. -// -// All objects are registered and created using base. If base cannot create a function handler, use last to create a function handler. -// -// NewHandlerExtendWarp 函数创建一个链式HandlerExtender对象。 -// -// The NewHandlerExtendWarp(base, last).RegisterHandlerExtend method uses the base object to register extension functions. -// -// The NewHandlerExtendWarp(base, last).NewHandlerFuncs method first uses the base object to create multiple request processing functions. If it returns nil, it uses the last object to create multiple request processing functions. -// -// 所有对象注册和创建均使用base,如果base无法创建函数处理者则使用last创建函数处理者。 -// -// NewHandlerExtendWarp(base, last).RegisterHandlerExtend 方法使用base对象注册扩展函数。 +// HandlerRouter403 function defines the default 403 processing. // -// NewHandlerExtendWarp(base, last).NewHandlerFuncs 方法先使用base对象创建多个请求处理函数,如果返回nil,则使用last对象创建多个请求处理函数。 -func NewHandlerExtendWarp(base, last HandlerExtender) HandlerExtender { - return &handlerExtendWarp{ - HandlerExtender: base, - LastExtender: last, - } +// HandlerRouter403 函数定义默认403处理。 +func HandlerRouter403(ctx Context) { + const page404 string = "403 forbidden" + ctx.WriteHeader(StatusForbidden) + _ = ctx.Render(page404) } -// The NewHandlerFuncs method implements the NewHandlerFuncs function. If the current HandlerExtender cannot create HandlerFuncs, it calls the superior HandlerExtender to process. +// HandlerRouter404 function defines the default 404 processing. // -// NewHandlerFuncs 方法实现NewHandlerFuncs函数,如果当前HandlerExtender无法创建HandlerFuncs,则调用上级HandlerExtender处理。 -func (ext *handlerExtendWarp) NewHandlerFuncs(path string, i interface{}) HandlerFuncs { - hs := ext.HandlerExtender.NewHandlerFuncs(path, i) - if hs != nil { - return hs - } - return ext.LastExtender.NewHandlerFuncs(path, i) -} - -// ListExtendHandlerNames 方法返回全部注册的函数名称。 -func (ext *handlerExtendWarp) ListExtendHandlerNames() []string { - return append(ext.LastExtender.ListExtendHandlerNames(), ext.HandlerExtender.ListExtendHandlerNames()...) +// HandlerRouter404 函数定义默认404处理。 +func HandlerRouter404(ctx Context) { + const page404 string = "404 page not found" + ctx.WriteHeader(StatusNotFound) + _ = ctx.Render(page404) } -// NewHandlerExtendTree function creates a path-based function extender. -// -// Mainly implement path matching. All actions are processed by the node's HandlerExtender, and the NewHandlerExtendBase () object is used. -// -// All registration and creation actions will be performed by matching the lowest node of the tree. If it cannot be created, the tree nodes will be processed upwards in order. -// -// The NewHandlerExtendTree().RegisterHandlerExtend method registers a handler function based on the path, and initializes to NewHandlerExtendBase () if the HandlerExtender is empty. +// HandlerRouter405 function defines the default 405 processing and returns Allow and X-Match-Route Header. // -// The NewHandlerExtendTree().NewHandlerFuncs method matches the child nodes of the tree based on the path, and then executes the NewHandlerFuncs method from the most child node up. If it returns non-null, it returns directly. -// -// NewHandlerExtendTree 函数创建一个基于路径的函数扩展者。 -// -// 主要实现路径匹配,所有行为使用节点的HandlerExtender处理,使用NewHandlerExtendBase()对象。 -// -// 所有注册和创建行为都会匹配树最下级节点执行,如果无法创建则在树节点依次向上处理。 -// -// NewHandlerExtendTree().RegisterHandlerExtend 方法基于路径注册一个处理函数,如果HandlerExtender为空则初始化为NewHandlerExtendBase()。 -// -// NewHandlerExtendTree().NewHandlerFuncs 方法基于路径向树子节点匹配,后从最子节点依次向上执行NewHandlerFuncs方法,如果返回非空直接返回,否在会依次执行注册行为。 -func NewHandlerExtendTree() HandlerExtender { - return &handlerExtendTree{} +// HandlerRouter405 函数定义默认405处理,返回Allow和X-Match-Route Header。 +func HandlerRouter405(ctx Context) { + const page405 string = "405 method not allowed" + ctx.SetHeader(HeaderAllow, ctx.GetParam(ParamAllow)) + ctx.SetHeader(HeaderXEudoreRoute, ctx.GetParam(ParamRoute)) + ctx.WriteHeader(StatusMethodNotAllowed) + _ = ctx.Render(page405) } -// RegisterHandlerExtend 方法基于路径注册一个扩展函数。 -func (ext *handlerExtendTree) RegisterHandlerExtend(path string, i interface{}) error { - // 匹配当前节点注册 - if path == "" { - if ext.HandlerExtender == nil { - ext.HandlerExtender = NewHandlerExtendBase() +// HandlerMetadata 函数返回从contextKey获取metadata,可以使用路由参数*或name指定key。 +func HandlerMetadata(ctx Context) { + name := GetAnyByString(ctx.GetParam("*"), ctx.GetParam("name")) + if name != "" { + meta := anyMetadata(ctx.Value(NewContextKey(name))) + if meta != nil { + _ = ctx.Render(meta) + } else { + HandlerRouter404(ctx) } - return ext.HandlerExtender.RegisterHandlerExtend("", i) - } - - // 寻找对应的子节点注册 - for pos := range ext.childs { - subStr, find := getSubsetPrefix(path, ext.childs[pos].path) - if find { - if subStr != ext.childs[pos].path { - ext.childs[pos].path = strings.TrimPrefix(ext.childs[pos].path, subStr) - ext.childs[pos] = &handlerExtendTree{ - path: subStr, - childs: []*handlerExtendTree{ext.childs[pos]}, - } - } - return ext.childs[pos].RegisterHandlerExtend(strings.TrimPrefix(path, subStr), i) - } - } - - // 追加一个新的子节点 - newnode := &handlerExtendTree{ - path: path, - HandlerExtender: NewHandlerExtendBase(), - } - ext.childs = append(ext.childs, newnode) - return newnode.HandlerExtender.RegisterHandlerExtend(path, i) -} - -// NewHandlerFuncs 函数基于路径创建多个对象处理函数。 -// -// 递归依次寻找子节点,然后返回时创建多个对象处理函数,如果子节点返回不为空就直接返回。 -func (ext *handlerExtendTree) NewHandlerFuncs(path string, data interface{}) HandlerFuncs { - for _, child := range ext.childs { - if strings.HasPrefix(path, child.path) { - hs := child.NewHandlerFuncs(path[len(child.path):], data) - if hs != nil { - return hs - } - break - } - } - - if ext.HandlerExtender != nil { - return ext.HandlerExtender.NewHandlerFuncs(path, data) + return } - return nil -} -// listExtendHandlerNamesByPrefix 方法递归添加路径前缀返回扩展函数名称。 -func (ext *handlerExtendTree) listExtendHandlerNamesByPrefix(prefix string) []string { - prefix += ext.path - var names []string - if ext.HandlerExtender != nil { - names = ext.HandlerExtender.ListExtendHandlerNames() - if prefix != "" { - for i := range names { - names[i] = prefix + " " + names[i] - } + keys := ctx.Value(ContextKeyAppKeys).([]any) + metas := make(map[string]any, len(keys)) + for i := range keys { + meta := anyMetadata(ctx.Value(keys[i])) + if meta != nil { + metas[fmt.Sprint(keys[i])] = meta } } - - for i := range ext.childs { - names = append(names, ext.childs[i].listExtendHandlerNamesByPrefix(prefix)...) - } - return names -} - -// ListExtendHandlerNames 方法返回全部注册的函数名称。 -func (ext *handlerExtendTree) ListExtendHandlerNames() []string { - return ext.listExtendHandlerNamesByPrefix("") + _ = ctx.Render(metas) } // NewHandlerFuncsFilter 函数过滤掉多个请求上下文处理函数中的空对象。 func NewHandlerFuncsFilter(hs HandlerFuncs) HandlerFuncs { - var num int + var size int for _, h := range hs { if h != nil { - num++ + size++ } } - if num == len(hs) { - return hs + if size == len(hs) { + return hs[:size:size] } // 返回新过滤空的处理函数。 - nhs := make(HandlerFuncs, 0, num) + nhs := make(HandlerFuncs, 0, size) for _, h := range hs { if h != nil { nhs = append(nhs, h) @@ -431,17 +117,17 @@ func NewHandlerFuncsFilter(hs HandlerFuncs) HandlerFuncs { func NewHandlerFuncsCombine(hs1, hs2 HandlerFuncs) HandlerFuncs { // if nil if len(hs1) == 0 { - return hs2 + return hs2[:len(hs2):len(hs2)] } if len(hs2) == 0 { - return hs1 + return hs1[:len(hs1):len(hs1)] } // combine - finalSize := len(hs1) + len(hs2) - if finalSize >= 127 { - panic("HandlerFuncsCombine: too many handlers") + size := len(hs1) + len(hs2) + if size >= DefaultContextMaxHandler { + panic(fmt.Errorf("HandlerFuncsCombine: too many handlers %d", size)) } - hs := make(HandlerFuncs, finalSize) + hs := make(HandlerFuncs, size) copy(hs, hs1) copy(hs[len(hs1):], hs2) return hs @@ -454,23 +140,23 @@ type reflectValue struct { } // getFuncPointer 函数获取一个reflect值的地址作为唯一标识id。 -func getFuncPointer(iValue reflect.Value) uintptr { - val := *(*reflectValue)(unsafe.Pointer(&iValue)) +func getFuncPointer(v reflect.Value) uintptr { + val := *(*reflectValue)(unsafe.Pointer(&v)) return val.ptr } // SetHandlerAliasName 函数设置一个函数处理对象原始名称,如果扩展未生成名称,使用此值。 // // 在handlerExtendBase对象和ControllerInjectSingleton函数中使用到,用于传递控制器函数名称。 -func SetHandlerAliasName(i interface{}, name string) { +func SetHandlerAliasName(i any, name string) { if name == "" { return } - iValue, ok := i.(reflect.Value) + v, ok := i.(reflect.Value) if !ok { - iValue = reflect.ValueOf(i) + v = reflect.ValueOf(i) } - val := *(*reflectValue)(unsafe.Pointer(&iValue)) + val := *(*reflectValue)(unsafe.Pointer(&v)) names := contextAliasName[val.ptr] index := int(val.flag >> 10) if len(names) <= index { @@ -482,11 +168,11 @@ func SetHandlerAliasName(i interface{}, name string) { names[index] = name } -func getHandlerAliasName(iValue reflect.Value) string { - val := *(*reflectValue)(unsafe.Pointer(&iValue)) +func getHandlerAliasName(v reflect.Value) string { + val := *(*reflectValue)(unsafe.Pointer(&v)) names := contextAliasName[val.ptr] index := int(val.flag >> 10) - if len(names) > index { + if index < len(names) { return names[index] } return "" @@ -523,279 +209,161 @@ func (h HandlerFunc) String() string { return runtime.FuncForPC(rh.Pointer()).Name() } -// NewExtendHandlerHTTP 函数handlerHTTP接口转换成HandlerFunc。 -func NewExtendHandlerHTTP(h handlerHTTP) HandlerFunc { - return h.HandleHTTP -} - -// NewExtendHandlerNetHTTP 函数转换处理http.Handler对象。 -func NewExtendHandlerNetHTTP(h http.Handler) HandlerFunc { - clone, ok := h.(interface{ CloneHandler() http.Handler }) - if ok { - h = clone.CloneHandler() - } - return func(ctx Context) { - h.ServeHTTP(ctx.Response(), ctx.Request()) - } +// NewHandlerStatic 函数使用多个any值创建混合静态文件处理函数。 +func NewHandlerStatic(dirs ...any) HandlerFunc { + return NewHandlerHTTPFileSystem(NewFileSystems(dirs...)) } -// NewExtendFuncNetHTTP1 函数转换处理func(http.ResponseWriter, *http.Request)类型。 -func NewExtendFuncNetHTTP1(fn func(http.ResponseWriter, *http.Request)) HandlerFunc { - return func(ctx Context) { - fn(ctx.Response(), ctx.Request()) - } +// NewHandlerEmbed 函数创建iofs.FS扩展函数。 +func NewHandlerEmbed(fs iofs.FS) HandlerFunc { + return NewHandlerHTTPFileSystem(NewFileSystems(fs)) } -// NewExtendFuncNetHTTP2 函数转换处理http.HandlerFunc类型。 -func NewExtendFuncNetHTTP2(fn http.HandlerFunc) HandlerFunc { +// NewHandlerHTTPFileSystem 函数创建http.FileSystem扩展函数。 +func NewHandlerHTTPFileSystem(fs http.FileSystem) HandlerFunc { return func(ctx Context) { - fn(ctx.Response(), ctx.Request()) - } -} - -// NewExtendFunc 函数处理func()。 -func NewExtendFunc(fn func()) HandlerFunc { - return func(Context) { - fn() - } -} - -func getFileLineFieldsVals(iValue reflect.Value) []interface{} { - file, line := runtime.FuncForPC(iValue.Pointer()).FileLine(1) - return []interface{}{file, line} -} - -// NewExtendFuncRender 函数处理func() interface{}。 -func NewExtendFuncRender(fn func() interface{}) HandlerFunc { - fineLineFieldsVals := getFileLineFieldsVals(reflect.ValueOf(fn)) - return func(ctx Context) { - data := fn() - if ctx.Response().Size() == 0 { - err := ctx.Render(data) - if err != nil { - ctx.WithFields(fineLineFieldsKeys, fineLineFieldsVals).Fatal(err) - } + path := filepath.Join(ctx.GetParam(ParamPrefix), ctx.GetParam("*")) + if path == "" { + path = "." } - } -} - -// NewExtendFuncError 函数处理func() error返回的error处理。 -func NewExtendFuncError(fn func() error) HandlerFunc { - fineLineFieldsVals := getFileLineFieldsVals(reflect.ValueOf(fn)) - return func(ctx Context) { - err := fn() + file, err := fs.Open(path) if err != nil { - ctx.WithFields(fineLineFieldsKeys, fineLineFieldsVals).Fatal(err) + if errors.Is(err, os.ErrNotExist) { + HandlerRouter404(ctx) + } else if errors.Is(err, os.ErrPermission) { + HandlerRouter403(ctx) + } + ctx.Error(err) + return } - } -} + defer file.Close() -// NewExtendFuncRenderError 函数处理func() (interface{}, error)返回数据渲染和error处理。 -func NewExtendFuncRenderError(fn func() (interface{}, error)) HandlerFunc { - fineLineFieldsVals := getFileLineFieldsVals(reflect.ValueOf(fn)) - return func(ctx Context) { - data, err := fn() - if err == nil && ctx.Response().Size() == 0 { - err = ctx.Render(data) + stat, _ := file.Stat() + // embed.FS的ModTime()为空无法使用缓存,设置为默认时间使用304缓存机制。 + modtime := stat.ModTime() + if modtime.IsZero() { + modtime = DefaultHandlerEmbedTime } - if err != nil { - ctx.WithFields(fineLineFieldsKeys, fineLineFieldsVals).Fatal(err) - } - } -} -// NewExtendFuncContextError 函数处理func(Context) error返回的error处理。 -func NewExtendFuncContextError(fn func(Context) error) HandlerFunc { - fineLineFieldsVals := getFileLineFieldsVals(reflect.ValueOf(fn)) - return func(ctx Context) { - err := fn(ctx) - if err != nil { - ctx.WithFields(fineLineFieldsKeys, fineLineFieldsVals).Fatal(err) + switch { + case !stat.IsDir(): + if ctx.Response().Header().Get(HeaderCacheControl) == "" { + ctx.SetHeader(HeaderCacheControl, DefaultHandlerEmbedCacheControl) + } + http.ServeContent(ctx.Response(), ctx.Request(), stat.Name(), modtime, file) + case GetAnyByString[bool](ctx.GetParam(ParamAutoIndex)): + ctx.SetHeader(HeaderCacheControl, "no-cache") + ctx.SetHeader(HeaderLastModified, modtime.UTC().Format(http.TimeFormat)) + handlerStaticDirs(ctx, "/"+ctx.GetParam("*"), file) + default: + ctx.WriteHeader(StatusNotFound) } } } -// NewExtendFuncContextRender 函数处理func(Context) interface{}返回数据渲染。 -func NewExtendFuncContextRender(fn func(Context) interface{}) HandlerFunc { - fineLineFieldsVals := getFileLineFieldsVals(reflect.ValueOf(fn)) - return func(ctx Context) { - data := fn(ctx) - if ctx.Response().Size() == 0 { - err := ctx.Render(data) - if err != nil { - ctx.WithFields(fineLineFieldsKeys, fineLineFieldsVals).Fatal(err) - } - } +func handlerStaticDirs(ctx Context, path string, file http.File) { + files, err := file.Readdir(-1) + if err != nil { + ctx.Fatal(err) + return } -} -// NewExtendFuncContextRenderError 函数处理func(Context) (interface{}, error)返回数据渲染和error处理。 -func NewExtendFuncContextRenderError(fn func(Context) (interface{}, error)) HandlerFunc { - fineLineFieldsVals := getFileLineFieldsVals(reflect.ValueOf(fn)) - return func(ctx Context) { - data, err := fn(ctx) - if err == nil && ctx.Response().Size() == 0 { - err = ctx.Render(data) + type fileInfo struct { + Name string `alias:"name" json:"name" xml:"name" yaml:"name"` + Size int64 `alias:"size" json:"size" xml:"size" yaml:"size"` + SizeFormat string `alias:"sizeformat" json:"sizeformat" xml:"sizeformat" yaml:"sizeformat"` + ModTime string `alias:"modtime" json:"modtime" xml:"modtime" yaml:"modtime"` + UnixTime int64 `alias:"unixtime" json:"unixtime" xml:"unixtime" yaml:"unixtime"` + IsDir bool `alias:"isdir" json:"isdir" xml:"isdir" yaml:"isdir"` + } + infos := make([]fileInfo, len(files)) + for i := range files { + infos[i] = fileInfo{ + Name: files[i].Name(), + Size: files[i].Size(), + SizeFormat: formatSize(files[i].Size()), + ModTime: files[i].ModTime().Format("1/2/06, 3:04:05 PM"), + UnixTime: files[i].ModTime().Unix(), + IsDir: files[i].IsDir(), } - if err != nil { - ctx.WithFields(fineLineFieldsKeys, fineLineFieldsVals).Fatal(err) + } + sort.Slice(infos, func(i, j int) bool { + if infos[i].IsDir == infos[j].IsDir { + return infos[i].Name < infos[j].Name } + return infos[i].IsDir + }) + + if ctx.GetParam(ParamTemplate) == "" { + ctx.SetParam(ParamTemplate, DefaultTemplateNameStaticIndex) } + _ = ctx.Render(struct { + Path string + Files []fileInfo + Upload bool + }{path, infos, GetAnyByString[bool](ctx.GetParam("upload"))}) } -// NewExtendFuncContextInterfaceError 函数处理func(Context) (T, error)返回数据渲染和error处理。 -func NewExtendFuncContextInterfaceError(fn interface{}) HandlerFunc { - iValue := reflect.ValueOf(fn) - iType := iValue.Type() - if iType.Kind() != reflect.Func || iType.NumIn() != 1 || iType.NumOut() != 2 || iType.In(0) != typeContext || iType.Out(1) != typeError { - return nil +func formatSize(n int64) string { + if n < 1024 { + return strconv.FormatInt(n, 10) + " B" } - - fineLineFieldsVals := getFileLineFieldsVals(reflect.ValueOf(fn)) - return func(ctx Context) { - vals := iValue.Call([]reflect.Value{reflect.ValueOf(ctx)}) - err, _ := vals[1].Interface().(error) - if err == nil && ctx.Response().Size() == 0 { - err = ctx.Render(vals[0].Interface()) - } - if err != nil { - ctx.WithFields(fineLineFieldsKeys, fineLineFieldsVals).Fatal(err) - } + sizes := []string{"B", "KB", "MB", "GB", "TB", "PB", "EB"} + e := math.Floor(math.Log(float64(n)) / math.Log(1024)) + v := float64(n) / math.Pow(2, e*10) + if v < 100 { + return fmt.Sprintf("%.1f %s", v, sizes[int(e)]) } + return fmt.Sprintf("%.0f %s", v, sizes[int(e)]) } -// NewExtendFuncRPCMap defines a fixed request and response to function processing of type map [string] interface {}. -// -// is a subset of NewRPCHandlerFunc and has type restrictions, but using map [string] interface {} to save requests does not use reflection. -// -// NewExtendFuncRPCMap 定义了固定请求和响应为map[string]interface{}类型的函数处理。 +// Combine multiple http.FileSystem // -// 是NewRPCHandlerFunc的一种子集,拥有类型限制,但是使用map[string]interface{}保存请求没用使用反射。 -func NewExtendFuncRPCMap(fn func(Context, map[string]interface{}) (interface{}, error)) HandlerFunc { - fineLineFieldsVals := getFileLineFieldsVals(reflect.ValueOf(fn)) - return func(ctx Context) { - req := make(map[string]interface{}) - err := ctx.Bind(&req) - if err != nil { - ctx.Fatal(err) - return - } - resp, err := fn(ctx, req) - if err == nil && ctx.Response().Size() == 0 { - err = ctx.Render(resp) - } - if err != nil { - ctx.WithFields(fineLineFieldsKeys, fineLineFieldsVals).Fatal(err) - } - } -} +// 组合多个http.FileSystem。 +type fileSystems []http.FileSystem -// NewExtendHandlerRPC function needs to pass in a function that returns a request for processing and is dynamically called by reflection. -// -// Function form: func (Context, Request) (Response, error) +// The NewFileSystems function creates a hybrid http.FileSystem object that returns the first file from multiple http.FileSystems. // -// The types of Request and Response can be map or struct or pointer to struct. All 4 parameters need to exist, and the order cannot be changed. +// If the type is string and the path exists, it will be converted to http.Dir; +// If the type is embed.FS converted to http.FS; +// If the type is http.FileSystem, add it directly. // -// NewExtendHandlerRPC 函数需要传入一个函数,返回一个请求处理,通过反射来动态调用。 +// NewFileSystems 函数创建一个混合http.FileSystem对象,从多个http.FileSystem返回首个文件。 // -// 函数形式: func(Context, Request) (Response, error) -// -// Request和Response的类型可以为map或结构体或者结构体的指针,4个参数需要全部存在,且不可调换顺序。 -func NewExtendHandlerRPC(fn interface{}) HandlerFunc { - iType := reflect.TypeOf(fn) - iValue := reflect.ValueOf(fn) - if iType.Kind() != reflect.Func { - return nil - } - if iType.NumIn() != 2 || iType.In(0) != typeContext { - return nil - } - if iType.NumOut() != 2 || iType.Out(1) != typeError { - return nil - } - var typeIn = iType.In(1) - var kindIn = typeIn.Kind() - var typenew = iType.In(1) - // 检查请求类型 - switch typeIn.Kind() { - case reflect.Ptr, reflect.Map, reflect.Slice, reflect.Struct: - default: - return nil - } - if typenew.Kind() == reflect.Ptr { - typenew = typenew.Elem() - } - - fineLineFieldsVals := getFileLineFieldsVals(iValue) - return func(ctx Context) { - // 创建请求参数并初始化 - req := reflect.New(typenew) - err := ctx.Bind(req.Interface()) - if err != nil { - ctx.Fatal(err) - return - } - if kindIn != reflect.Ptr { - req = req.Elem() - } - - // 反射调用执行函数。 - vals := iValue.Call([]reflect.Value{reflect.ValueOf(ctx), req}) - - // 检查函数执行err。 - err, ok := vals[1].Interface().(error) - if ok { - ctx.WithFields(fineLineFieldsKeys, fineLineFieldsVals).Fatal(err) - return - } - - // 渲染返回的数据。 - err = ctx.Render(vals[0].Interface()) - if err != nil { - ctx.Fatal(err) +// 如果类型为string且路径存在将转换成http.Dir; +// 如果类型为embed.FS转换成http.FS; +// 如果类型为http.FileSystem直接追加。 +func NewFileSystems(dirs ...any) http.FileSystem { + var fs fileSystems + for i := range dirs { + switch dir := dirs[i].(type) { + case string: + _, err := os.Stat(dir) + if err == nil { + fs = append(fs, http.Dir(dir)) + } + case iofs.FS: + fs = append(fs, http.FS(dir)) + case fileSystems: + fs = append(fs, dir...) + case http.FileSystem: + fs = append(fs, dir) } } -} - -// NewExtendHandlerStringer 函数处理fmt.Stringer接口类型转换成HandlerFunc。 -func NewExtendHandlerStringer(fn fmt.Stringer) HandlerFunc { - return func(ctx Context) { - ctx.WriteString(fn.String()) - } -} - -// NewExtendFuncString 函数处理func() string,然后指定函数生成的字符串。 -func NewExtendFuncString(fn func() string) HandlerFunc { - return func(ctx Context) { - ctx.WriteString(fn()) + if len(fs) == 1 { + return fs[0] } + return fs } -// NewStaticHandler 函数更加目标创建一个静态文件处理函数。 -// -// 参数dir指导打开文件的根目录,默认未"." -// -// 路由规则可以指导path参数为请求文件路径,例如/static/*path,将会去打开path参数路径的文件,否在使用ctx.Path(). -func NewStaticHandler(name, dir string) HandlerFunc { - if name == "" { - name = "*" - } - if dir == "" { - dir = "." - } - return func(ctx Context) { - path := ctx.GetParam(name) - if path == "" { - path = ctx.Path() +// Open 方法从多个http.FileSystem返回首个文件。 +func (fs fileSystems) Open(name string) (file http.File, err error) { + err = os.ErrNotExist + for _, f := range fs { + file, err = f.Open(name) + if err == nil { + return file, nil } - if ctx.Request().Header.Get(HeaderCacheControl) == "" { - ctx.SetHeader(HeaderCacheControl, "no-cache") - } - ctx.WriteFile(filepath.Join(dir, filepath.Clean("/"+path))) } -} - -// HandlerEmpty 函数定义一个空的请求上下文处理函数。 -func HandlerEmpty(Context) { - // Do nothing because empty handler does not process entries. + return } diff --git a/handlerdata.go b/handlerdata.go index b7945ca..ff6477f 100644 --- a/handlerdata.go +++ b/handlerdata.go @@ -1,32 +1,38 @@ package eudore import ( - "context" + "bytes" "encoding/json" "encoding/xml" + "errors" "fmt" "html/template" "net/http" "reflect" - "regexp" - "strconv" "strings" - "sync" ) -// HandlerDataFunc 定义请求上下文数据出来函数。 +// HandlerDataFunc defines the request context data processing function. // -// 默认定义Bind Validate Filte Render四种行为。 +// Define four behaviors of Bind Validater Filter Render by default. // -// Binder对象用于请求数据反序列化,默认根据http请求的Content-Type header指定的请求数据格式来解析数据。 +// Binder object is used to request data deserialization, +// By default, data is parsed according to the request data format specified +// by the Content-Type header of the http request. +// +// The Renderer object accepts the header to select the data object serialization method. +// HandlerDataFunc 定义请求上下文数据处理函数。 +// +// 默认定义Bind Validater Filter Render四种行为。 +// +// Binder对象用于请求数据反序列化, +// 默认根据http请求的Content-Type header指定的请求数据格式来解析数据。 // // Renderer对象更加Accept Header选择数据对象序列化的方法。 -type HandlerDataFunc = func(Context, interface{}) error - -func init() { - DefaultRenderHTMLTemplate, _ = template.New("render").Parse(renderHTMLTempdate) -} +type HandlerDataFunc = func(Context, any) error +// The NewBinds method defines the ContentType Header mapping Bind function. +// // NewBinds 方法定义ContentType Header映射Bind函数。 func NewBinds(binds map[string]HandlerDataFunc) HandlerDataFunc { if binds == nil { @@ -43,9 +49,9 @@ func NewBinds(binds map[string]HandlerDataFunc) HandlerDataFunc { mimes += ", " + k } mimes = strings.TrimPrefix(mimes, ", ") - return func(ctx Context, i interface{}) error { + return func(ctx Context, i any) error { contentType := ctx.GetHeader(HeaderContentType) - if ctx.Method() == MethodGet || ctx.Method() == MethodHead || contentType == "" { + if contentType == "" { return BindURL(ctx, i) } fn, ok := binds[strings.SplitN(contentType, ";", 2)[0]] @@ -63,81 +69,135 @@ func NewBinds(binds map[string]HandlerDataFunc) HandlerDataFunc { } } -// NewBindWithHeader 实现Binder额外封装bind header。 +// NewBindWithHeader implements Bind to additionally encapsulate bind header. +// +// NewBindWithHeader 实现Bind额外封装bind header。 func NewBindWithHeader(fn HandlerDataFunc) HandlerDataFunc { - return func(ctx Context, i interface{}) error { + return func(ctx Context, i any) error { BindHeader(ctx, i) return fn(ctx, i) } } -// NewBindWithURL 实现Binder在非get和head方法下实现BindURL。 +// NewBindWithURL implements Bind and also executes BindURL +// when HeaderContentType is not empty. +// +// NewBindWithURL 实现Bind在HeaderContentType非空时也执行BindURL。 func NewBindWithURL(fn HandlerDataFunc) HandlerDataFunc { - return func(ctx Context, i interface{}) error { - if ctx.Method() != MethodGet && ctx.Method() != MethodHead { + return func(ctx Context, i any) error { + if ctx.GetHeader(HeaderContentType) != "" { BindURL(ctx, i) } return fn(ctx, i) } } +func bindMaps[T any](data map[string][]T, i any, tags []string) error { + for key, vals := range data { + for _, val := range vals { + SetAnyByPathWithTag(i, key, val, tags, false) + } + } + return nil +} + +// The BindURL function uses the url parameter to parse the binding body. +// // BindURL 函数使用url参数解析绑定body。 -func BindURL(ctx Context, i interface{}) error { - return ConvertToWithTags(ctx.Querys(), i, DefaultBindURLTags) +func BindURL(ctx Context, i any) error { + return bindMaps(ctx.Querys(), i, DefaultHandlerBindURLTags) } +// The BindForm function uses form to parse and bind the body. +// // BindForm 函数使用form解析绑定body。 -func BindForm(ctx Context, i interface{}) error { - ConvertToWithTags(ctx.FormFiles(), i, DefaultBindFormTags) - return ConvertToWithTags(ctx.FormValues(), i, DefaultBindFormTags) +func BindForm(ctx Context, i any) error { + bindMaps(ctx.FormFiles(), i, DefaultHandlerBindFormTags) + return bindMaps(ctx.FormValues(), i, DefaultHandlerBindFormTags) } -// BindJSON 函数使用json解析绑定body。 -func BindJSON(ctx Context, i interface{}) error { +// The BindJSON function uses encoding/json to parse and bind the body. +// +// BindJSON 函数使用encoding/json解析绑定body。 +func BindJSON(ctx Context, i any) error { return json.NewDecoder(ctx).Decode(i) } -// BindXML 函数使用xml解析绑定body。 -func BindXML(ctx Context, i interface{}) error { +// The BindXML function uses encoding/xml to parse the bound body. +// +// BindXML 函数使用encoding/xml解析绑定body。 +func BindXML(ctx Context, i any) error { return xml.NewDecoder(ctx).Decode(i) } +// The BindProtobuf function uses the built-in protobu to parse the bound body. +// // BindProtobuf 函数使用内置protobu解析绑定body。 -func BindProtobuf(ctx Context, i interface{}) error { +func BindProtobuf(ctx Context, i any) error { return NewProtobufDecoder(ctx).Decode(i) } +// The BindHeader function implements binding using header data. +// +// The header name prefix must be 'X-', example: X-Euduore-Name => Eudore.Name. +// // BindHeader 函数实现使用header数据bind。 -func BindHeader(ctx Context, i interface{}) error { - return ConvertToWithTags(ctx.Request().Header, i, DefaultBindHeaderTags) +// +// header名称前缀必须是'X-',example: X-Euduore-Name => Eudore.Name。 +func BindHeader(ctx Context, i any) error { + for key, vals := range ctx.Request().Header { + if strings.HasPrefix(key, "X-") { + key = strings.ReplaceAll(key[2:], "-", ".") + for _, val := range vals { + SetAnyByPathWithTag(i, key, val, DefaultHandlerBindHeaderTags, false) + } + } + } + return nil } -// NewRenders 方法定义默认和Accepte Header映射Render函数。 +// The NewRenders method defines the default HeaderAccept value mapping Render function. +// +// The HeaderAccept value ignores non-zero weight values, and the order takes precedence. +// +// NewRenders 方法定义默认HeaderAccept值映射Render函数。 +// +// HeaderAccept值忽略非零权重值,顺序优先。 func NewRenders(renders map[string]HandlerDataFunc) HandlerDataFunc { if renders == nil { renders = map[string]HandlerDataFunc{ + MimeText: RenderText, + MimeTextPlain: RenderText, + MimeTextHTML: RenderHTML, MimeApplicationJSON: RenderJSON, MimeApplicationProtobuf: RenderProtobuf, MimeApplicationXML: RenderXML, - MimeTextHTML: RenderHTML, - MimeTextPlain: RenderText, } } - return func(ctx Context, i interface{}) error { + render, ok := renders["*"] + if !ok { + render = DefaultHandlerRenderFunc + } + return func(ctx Context, i any) error { for _, accept := range strings.Split(ctx.GetHeader(HeaderAccept), ",") { - pos := strings.IndexByte(accept, ';') - if pos != -1 { - accept = accept[:pos] + name, quality, ok := strings.Cut(strings.TrimSpace(accept), ";") + if ok && quality == "q=0" { + continue } - fn, ok := renders[strings.TrimSpace(accept)] + + fn, ok := renders[name] if ok && fn != nil { + h := ctx.Response().Header() + v := h.Values(HeaderVary) + h.Set(HeaderVary, strings.Join(append(v, HeaderAccept), ", ")) err := fn(ctx, i) - if err != ErrRenderHandlerSkip { + if !errors.Is(err, ErrRenderHandlerSkip) { return err } + h[HeaderVary] = v } } - return DefaultRenderFunc(ctx, i) + return render(ctx, i) } } @@ -148,10 +208,14 @@ func renderSetContentType(ctx Context, mime string) { } } +// The RenderJSON function uses the encoding/json library to implement json deserialization. +// +// If the request Accept is not "application/json", output in json indent format. +// // RenderJSON 函数使用encoding/json库实现json反序列化。 // // 如果请求Accept不为"application/json",使用json indent格式输出。 -func RenderJSON(ctx Context, data interface{}) error { +func RenderJSON(ctx Context, data any) error { renderSetContentType(ctx, MimeApplicationJSONCharsetUtf8) switch reflect.Indirect(reflect.ValueOf(data)).Kind() { case reflect.Struct, reflect.Map, reflect.Slice, reflect.Array: @@ -165,596 +229,93 @@ func RenderJSON(ctx Context, data interface{}) error { return encoder.Encode(data) } +// RenderXML function Render Xml, +// using the encoding/xml library to realize xml deserialization. +// // RenderXML 函数Render Xml,使用encoding/xml库实现xml反序列化。 -func RenderXML(ctx Context, data interface{}) error { +func RenderXML(ctx Context, data any) error { renderSetContentType(ctx, MimeApplicationXMLCharsetUtf8) return xml.NewEncoder(ctx).Encode(data) } +// RenderText function Render Text, written using the fmt.Fprint function. +// // RenderText 函数Render Text,使用fmt.Fprint函数写入。 -func RenderText(ctx Context, data interface{}) error { +func RenderText(ctx Context, data any) error { renderSetContentType(ctx, MimeTextPlainCharsetUtf8) if s, ok := data.(string); ok { - return ctx.WriteString(s) + ctx.WriteString(s) + return nil } if s, ok := data.(fmt.Stringer); ok { - return ctx.WriteString(s.String()) + ctx.WriteString(s.String()) + return nil } _, err := fmt.Fprintf(ctx, "%#v", data) return err } +// RenderProtobuf function Render Protobuf, +// using the built-in protobuf encoding, invalid properties will be ignored. +// // RenderProtobuf 函数Render Protobuf,使用内置protobuf编码,无效属性将忽略。 -func RenderProtobuf(ctx Context, i interface{}) error { +func RenderProtobuf(ctx Context, i any) error { renderSetContentType(ctx, MimeApplicationProtobuf) return NewProtobufEncoder(ctx).Encode(i) } +// The RenderHTML function creates a template Renderer using a template. +// +// Load *template.Template from ctx.Value(eudore.ContextKeyTemplate), +// Load the template function from ctx.GetParam("template"). +// // RenderHTML 函数使用模板创建一个模板Renderer。 // // 从ctx.Value(eudore.ContextKeyTemplate)加载*template.Template, // 从ctx.GetParam("template")加载模板函数。 -func RenderHTML(ctx Context, data interface{}) error { - t, ok := ctx.Value(ContextKeyTemplate).(*template.Template) - if ok { - // 模板必须加载name,防止渲染空模板 - name := ctx.GetParam("template") - if name != "" { - t = t.Lookup(name) - if t != nil { - renderSetContentType(ctx, MimeTextHTMLCharsetUtf8) - return t.Execute(ctx, data) - } +func RenderHTML(ctx Context, data any) error { + tpl, ok := ctx.Value(ContextKeyTemplate).(*template.Template) + if !ok { + return ErrRenderHandlerSkip + } + + name := ctx.GetParam("template") + if name == "" { + // 默认模板 + name = DefaultTemplateNameRenderData + b := bytes.NewBuffer([]byte{}) + en := json.NewEncoder(b) + en.SetEscapeHTML(false) + en.SetIndent("", "\t") + err := en.Encode(data) + if err != nil { + b.WriteString(err.Error()) } - } - - if DefaultRenderHTMLTemplate != nil { - b, _ := json.MarshalIndent(data, "", "\t") - renderSetContentType(ctx, MimeTextHTMLCharsetUtf8) - return DefaultRenderHTMLTemplate.Execute(ctx, map[string]interface{}{ - "Method": ctx.Method(), - "Host": ctx.Host(), - "Path": ctx.Request().RequestURI, - "Query": ctx.Querys(), - "Status": fmt.Sprintf("%d %s", ctx.Response().Status(), http.StatusText(ctx.Response().Status())), - "RemoteAddr": ctx.RealIP(), + data = map[string]any{ + "Method": ctx.Method(), + "Host": ctx.Host(), + "Path": ctx.Request().RequestURI, + "Query": ctx.Querys(), + "Status": fmt.Sprintf("%d %s", + ctx.Response().Status(), + http.StatusText(ctx.Response().Status()), + ), + "RemoteAddr": ctx.Host(), + "LocalAddr": ctx.RealIP(), "Params": ctx.Params(), "RequestHeader": ctx.Request().Header, "ResponseHeader": ctx.Response().Header(), - "Data": string(b), + "Data": b.String(), "GodocServer": DefaultGodocServer, "TraceServer": DefaultTraceServer, - }) - } - return ErrRenderHandlerSkip -} - -var renderHTMLTempdate = ` - - - - Eudore Look Value - - - - - - - -
-

Eudore default render html

-
- General -
Request URL: {{.Host}}{{.Path}}
-
Request Method: {{.Method}}
-
Status Code: {{.Status}}
-
Remote Address: {{.RemoteAddr}}
-
- {{- if ne (len .Query) 0 }} -
- Requesst Querys - {{- range $key, $vals := .Query -}} - {{- range $i, $val := $vals }} -
{{$key}}: {{$val}}
- {{- end }} - {{- end }} -
- {{- end }} -
- Requesst Params - {{- $iskey := true }} - {{- range $i,$val := .Params}} - {{- if $iskey}} -
{{$val}}: {{- else}}{{$val}}
{{end}} - {{- $iskey = not $iskey}} - {{- end}} -
-
- Request Headers - {{- range $key, $vals := .RequestHeader -}} - {{- range $i, $val := $vals }} -
{{$key}}: {{$val}}
- {{- end }} - {{- end }} -
-
- {{- $trace := .TraceServer }} - Response Headers - {{- range $key, $vals := .ResponseHeader -}} - {{- range $i, $val := $vals }} - {{- if and (eq $key "X-Trace-Id") (ne $trace "")}} -
{{$key}}: {{$val}}
- {{- else }} -
{{$key}}: {{$val}}
- {{- end }} - {{- end }} - {{- end }} -
-
- Response Data -
{{.Data}}
-
-
- -` - -// NewValidateField 方法创建结构体属性校验器。 -// -// 使用结构体tag validate从FuncCreator获取校验函数。 -// 获取ContextKeyFuncCreator.(FuncCreator)用于创建校验函数。 -func NewValidateField(ctx context.Context) HandlerDataFunc { - fc, ok := ctx.Value(ContextKeyFuncCreator).(FuncCreator) - if !ok { - fc = DefaultFuncCreator - } - validater := &validateField{ - ValidateTypes: make(map[reflect.Type][]validateFieldValue), - FuncCreator: fc, - } - return func(_ Context, i interface{}) error { - return validater.Validate(i) - } -} - -type validateField struct { - sync.Map - ValidateTypes map[reflect.Type][]validateFieldValue - FuncCreator FuncCreator -} - -type validateFieldValue struct { - Index int - Value reflect.Value - Format string -} - -// Validate 方法校验一个对象属性。 -// -// 允许类型为struct []struct []*struct []interface -func (v *validateField) Validate(i interface{}) error { - iValue := reflect.Indirect(reflect.ValueOf(i)) - switch iValue.Kind() { - case reflect.Struct: - return v.validate(iValue, nil) - case reflect.Slice, reflect.Array: - switch iValue.Type().Elem().Kind() { - case reflect.Struct: - // []struct - vfs, err := v.parseStructFields(iValue.Type().Elem()) - if err != nil || len(vfs) == 0 { - return err - } - for i := 0; i < iValue.Len(); i++ { - err = v.validate(iValue.Index(i), vfs) - if err != nil { - return err - } - } - case reflect.Interface, reflect.Ptr: - // []*struct - // []interface{}{*structA} - for i := 0; i < iValue.Len(); i++ { - field := reflect.Indirect(iValue.Index(i)) - if field.Kind() == reflect.Struct { - err := v.validate(field, nil) - if err != nil { - return err - } - } - } - } - } - return nil -} - -func (v *validateField) validate(iValue reflect.Value, vfs []validateFieldValue) error { - if vfs == nil { - var err error - vfs, err = v.parseStructFields(iValue.Type()) - if err != nil { - return err } } - // 匹配验证器规则 - for _, i := range vfs { - field := iValue.Field(i.Index) - // 反射调用Validater检测函数 - out := i.Value.Call([]reflect.Value{field}) - if !out[0].Bool() { - return fmt.Errorf(i.Format, field.Interface()) - } - } - return nil -} - -func (v *validateField) parseStructFields(iType reflect.Type) ([]validateFieldValue, error) { - data, ok := v.Load(iType) - if ok { - return data.([]validateFieldValue), nil + tpl = tpl.Lookup(name) + if tpl == nil { + return ErrRenderHandlerSkip } - var vfs []validateFieldValue - for i := 0; i < iType.NumField(); i++ { - field := iType.Field(i) - tags := field.Tag.Get(DefaultValidateTag) - // 解析tags - for _, tag := range strings.Split(tags, " ") { - if tag == "" { - continue - } - fn, err := v.FuncCreator.Create(field.Type, tag) - if err != nil { - return nil, fmt.Errorf(ErrFormatParseValidateFieldError, iType.PkgPath(), iType.Name(), field.Name, tag, err.Error()) - } - vfs = append(vfs, validateFieldValue{ - Index: i, - Value: reflect.ValueOf(fn), - Format: fmt.Sprintf("validate %s.%s field %s check %s rule fatal, value: %%#v", iType.PkgPath(), iType.Name(), field.Name, tag), - }) - } - } - - v.Store(iType, vfs) - return vfs, nil -} - -// FuncCreator 定义校验函数构造器,默认由RouterStd和validateField使用。 -type FuncCreator interface { - Register(string, ...interface{}) error - Create(reflect.Type, string) (interface{}, error) -} - -type funcCreator struct { - sync.RWMutex - Logger Logger - // 验证规则 - 验证类型 - 验证函数 - FuncValues map[string]map[reflect.Type]interface{} - // 验证规则 - 验证生成函数 - FuncNews map[string]map[reflect.Type]reflect.Value -} - -// NewFuncCreator 函数创建默认校验函数构造器。 -func NewFuncCreator() FuncCreator { - fc := &funcCreator{ - Logger: NewLoggerNull(), - FuncValues: make(map[string]map[reflect.Type]interface{}), - FuncNews: make(map[string]map[reflect.Type]reflect.Value), - } - fc.initFunc() - return fc -} - -// Mount 方法获取ContextKeyApp.(Logger)作为默认日志输出。 -func (fc *funcCreator) Mount(ctx context.Context) { - logger, ok := ctx.Value(ContextKeyApp).(Logger) - if ok { - fc.Logger = logger.WithField("creator", "funcCreator").WithField("depth", 2).WithField("logger", true) - fc.initFunc() - } -} - -func (fc *funcCreator) initFunc() { - fc.Register("nozero", validateIntNozero, validateStringNozero, validateInterfaceNozero) - fc.Register("isnum", validateStringIsnum) - fc.Register("min", validateNewIntMin, validateNewStringMin) - fc.Register("max", validateNewIntMax, validateNewStringMax) - fc.Register("len", validateNewStringLen, validateNewInterfaceLen) - fc.Register("regexp", validateNewStringRegexp) -} - -// Register 函数给一个名称注册多个类型的的ValidateFunc或ValidateNewFunc。 -// -// ValidateFunc func(T) bool -// -// ValidateNewFunc func(string) (func(T) bool, error) -func (fc *funcCreator) Register(name string, fns ...interface{}) error { - fc.Lock() - defer fc.Unlock() - var errs errormulit - for _, fn := range fns { - errs.HandleError(fc.registerFunc(name, fn)) - } - return errs.Unwrap() -} - -// registerFunc 函数注册一个ValidateFunc或ValidateNewFunc -func (fc *funcCreator) registerFunc(name string, fn interface{}) error { - iType := reflect.TypeOf(fn) - - if checkValidateFunc(iType) { - if fc.FuncValues[name] == nil { - fc.FuncValues[name] = make(map[reflect.Type]interface{}) - } - fc.FuncValues[name][iType.In(0)] = fn - fc.Logger.Debugf("Register func %s %T", name, fn) - return nil - } - - if iType.Kind() == reflect.Func && iType.NumIn() == 1 && iType.NumOut() == 2 && iType.In(0) == typeString && iType.Out(1) == typeError { - fType := iType.Out(0) - if checkValidateFunc(fType) { - if fc.FuncNews[name] == nil { - fc.FuncNews[name] = make(map[reflect.Type]reflect.Value) - } - fc.FuncNews[name][fType.In(0)] = reflect.ValueOf(fn) - return nil - } - } - - err := fmt.Errorf(ErrFormatFuncCreatorRegisterInvalidType, name, fn) - fc.Logger.Error(err) - return err -} - -// Create 方法获取或创建一个校验函数。 -// func(Type) bool/ func(interface{}) bool/ error/ func(Type) Func -func (fc *funcCreator) Create(iType reflect.Type, fullname string) (interface{}, error) { - fc.RLock() - fvs, ok := fc.FuncValues[fullname] - if ok { - fn, ok := fvs[iType] - if ok { - fc.RUnlock() - return fn, nil - } - } - fc.RUnlock() - - // 升级锁 - fc.Lock() - defer fc.Unlock() - - name, arg := getValidateNameArg(fullname) - if arg != "" { - fns, ok := fc.FuncNews[name] - if ok { - fn, ok := fns[iType] - if ok { - vals := fn.Call([]reflect.Value{reflect.ValueOf(arg)}) - fn, err := vals[0].Interface(), vals[1].Interface() - if err != nil { - fc.Logger.Errorf("Create func %s error: %v", fullname, err) - return nil, err.(error) - } - fc.registerFunc(fullname, fn) - return fn, nil - } - } - } - - return nil, fmt.Errorf(ErrFormatFuncCreatorNotFunc, fullname) -} - -func checkValidateFunc(iType reflect.Type) bool { - if iType.Kind() != reflect.Func { - return false - } - if iType.NumIn() != 1 || iType.NumOut() != 1 { - return false - } - if iType.Out(0) != typeBool { - return false - } - return true -} - -func getValidateNameArg(name string) (string, string) { - for i, b := range name { - // ! [0-9A-Za-z] - if b < 0x30 || (0x39 < b && b < 0x41) || (0x5A < b && b < 0x61) || 0x7A < b { - return name[:i], name[i:] - } - } - return name, "" -} -func getValidateNameNumber(name string) string { - var number string - for i, b := range name { - if 0x2F < b && b < 0x3A { - number += name[i : i+1] - } - } - return number -} - -// validateIntNozero 函数验证一个int是否为零 -func validateIntNozero(num int) bool { - return num != 0 -} - -// validateStringNozero 函数验证一个字符串是否为空 -func validateStringNozero(str string) bool { - return str != "" -} - -// validateInterfaceNozero 函数验证一个对象是否为零值,使用reflect.Value.IsZero函数实现。 -func validateInterfaceNozero(i interface{}) bool { - return !reflect.ValueOf(i).IsZero() -} - -// validateStringIsnum 函数验证一个字符串是否为数字。 -func validateStringIsnum(str string) bool { - _, err := strconv.Atoi(str) - return err == nil -} - -// validateNewIntMin 函数生成一个验证int最小值的验证函数。 -func validateNewIntMin(str string) (func(int) bool, error) { - str = getValidateNameNumber(str) - min, err := strconv.ParseInt(str, 10, 32) - if err != nil { - return nil, err - } - intmin := int(min) - return func(num int) bool { - return num >= intmin - }, nil -} - -// validateNewIntMax 函数生成一个验证int最大值的验证函数。 -func validateNewIntMax(str string) (func(int) bool, error) { - str = getValidateNameNumber(str) - max, err := strconv.ParseInt(str, 10, 32) - if err != nil { - return nil, err - } - intmax := int(max) - return func(num int) bool { - return num <= intmax - }, nil -} - -// validateNewStringMin 函数生成一个验证string最小值的验证函数。 -func validateNewStringMin(str string) (func(string) bool, error) { - str = getValidateNameNumber(str) - min, err := strconv.ParseInt(str, 10, 32) - if err != nil { - return nil, err - } - intmin := int(min) - return func(arg string) bool { - num, err := strconv.Atoi(arg) - if err != nil { - return false - } - return num >= intmin - }, nil -} - -// validateNewStringMax 函数生成一个验证string最大值的验证函数。 -func validateNewStringMax(str string) (func(string) bool, error) { - str = getValidateNameNumber(str) - max, err := strconv.ParseInt(str, 10, 32) - if err != nil { - return nil, err - } - intmax := int(max) - return func(arg string) bool { - num, err := strconv.Atoi(arg) - if err != nil { - return false - } - return num <= intmax - }, nil -} - -// validateNewStringLen 函数生一个验证字符串长度'>','<','='指定长度的验证函数。 -func validateNewStringLen(str string) (func(s string) bool, error) { - var flag string - for _, i := range []string{">", "<", "=", ""} { - if strings.HasPrefix(str, i) { - flag = i - str = str[len(i):] - break - } - } - - length, err := strconv.ParseInt(str, 10, 32) - if err != nil { - return nil, err - } - intlength := int(length) - switch flag { - case ">": - return func(s string) bool { - return len(s) > intlength - }, nil - case "<": - return func(s string) bool { - return len(s) < intlength - }, nil - default: - return func(s string) bool { - return len(s) == intlength - }, nil - } -} - -// validateNewInterfaceLen 函数生一个验证字符串长度'>','<','='指定长度的验证函数。 -func validateNewInterfaceLen(str string) (func(i interface{}) bool, error) { - var flag string - for _, i := range []string{">", "<", "=", ""} { - if strings.HasPrefix(str, i) { - flag = i - str = str[len(i):] - break - } - } - - length, err := strconv.ParseInt(str, 10, 32) - if err != nil { - return nil, err - } - intlength := int(length) - switch flag { - case ">": - return func(i interface{}) bool { - iValue := reflect.Indirect(reflect.ValueOf(i)) - switch iValue.Kind() { - case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String: - return iValue.Len() > intlength - default: - return false - } - }, nil - case "<": - return func(i interface{}) bool { - iValue := reflect.Indirect(reflect.ValueOf(i)) - switch iValue.Kind() { - case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String: - return iValue.Len() < intlength - default: - return false - } - }, nil - default: - return func(i interface{}) bool { - iValue := reflect.Indirect(reflect.ValueOf(i)) - switch iValue.Kind() { - case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String: - return iValue.Len() == intlength - default: - return false - } - }, nil - } -} - -// validateNewStringRegexp 函数生成一个正则检测字符串的验证函数。 -func validateNewStringRegexp(str string) (func(arg string) bool, error) { - re, err := regexp.Compile(str) - if err != nil { - return nil, err - } - // 返回正则匹配校验函数 - return func(arg string) bool { - return re.MatchString(arg) - }, nil + renderSetContentType(ctx, MimeTextHTMLCharsetUtf8) + return tpl.Execute(ctx, data) } diff --git a/handlerdata2.go b/handlerdata2.go new file mode 100644 index 0000000..0b79b96 --- /dev/null +++ b/handlerdata2.go @@ -0,0 +1,328 @@ +package eudore + +import ( + "context" + "fmt" + "reflect" + "strings" + "sync" +) + +type validateField struct { + sync.Map + FuncCreator FuncCreator +} + +type validateFieldValue struct { + Index int + Omit bool + Func FuncRunner + Format string +} + +// The NewValidateField method creates a struct property validator. +// +// Get ContextKeyFuncCreator.(FuncCreator) to create a verification function. +// Use the structure tag validate to get the verification function from FuncCreator. +// +// Allowed types are struct []struct []*struct []interface. +// +// Only verify that the field type is string/int/uint/float/bool/any, +// and the int-related numerical type is converted to int and then verified, +// and the precision may be lost. +// +// NewValidateField 方法创建结构体属性校验器。 +// +// 获取ContextKeyFuncCreator.(FuncCreator)用于创建校验函数。 +// 使用结构体tag validate从FuncCreator获取校验函数。 +// +// 允许类型为struct []struct []*struct []interface。 +// +// 仅校验字段类型为string/int/uint/float/bool/any,int相关数值类型转换成int后校验,可能精度丢失。 +func NewValidateField(ctx context.Context) HandlerDataFunc { + vf := &validateField{FuncCreator: NewFuncCreatorWithContext(ctx)} + return func(ctx Context, i any) error { + c := ctx.GetContext() + v := reflect.Indirect(reflect.ValueOf(i)) + switch v.Kind() { + case reflect.Struct: + return vf.validateFields(c, i, v) + case reflect.Slice, reflect.Array: + // []struct []*struct []any + for i := 0; i < v.Len(); i++ { + field := v.Index(i) + for field.Kind() == reflect.Ptr || field.Kind() == reflect.Interface { + field = field.Elem() + } + if field.Kind() == reflect.Struct { + err := vf.validateFields(c, v.Index(i), field) + if err != nil { + return err + } + } + } + } + return nil + } +} + +func (vf *validateField) validateFields(c context.Context, i any, v reflect.Value) error { + fields, err := vf.parseStructFields(v.Type()) + if err != nil { + return err + } + + if validater, ok := i.(interface{ Validate(context.Context) error }); ok { + if err := validater.Validate(c); err != nil { + return err + } + } + + // 匹配验证器规则 + for _, i := range fields { + field := v.Field(i.Index) + if i.Omit && field.IsZero() { + continue + } + if !i.Func.RunPtr(field) { + return fmt.Errorf(i.Format, field.Interface()) + } + } + return nil +} + +func (vf *validateField) parseStructFields(iType reflect.Type) ([]validateFieldValue, error) { + data, ok := vf.Load(iType) + if ok { + switch val := data.(type) { + case []validateFieldValue: + return val, nil + case error: + return nil, val + } + } + + var fields []validateFieldValue + for i := 0; i < iType.NumField(); i++ { + t := iType.Field(i) + tags, omit := cutOmit(t.Tag.Get(DefaultHandlerValidateTag)) + if tags == "-" { + continue + } + + if t.Anonymous { + et := t.Type + if et.Kind() == reflect.Ptr { + et = et.Elem() + } + if et.Kind() == reflect.Struct { + f, err := vf.parseStructFields(et) + if err != nil { + vf.Store(iType, err) + return nil, err + } + fields = append(fields, f...) + continue + } + } + + for _, tag := range splitValidateTag(tags) { + kind := NewFuncCreateKindWithType(t.Type) + fn, err := vf.FuncCreator.CreateFunc(kind, tag) + if err != nil { + err = fmt.Errorf(ErrFormatValidateParseFieldError, + iType.PkgPath(), iType.Name(), t.Name, tag, err.Error()) + vf.Store(iType, err) + return nil, err + } + + val := validateFieldValue{ + Index: i, Omit: omit, Func: FuncRunner{kind, fn}, + Format: fmt.Sprintf(ErrFormatValidateErrorFormat, + iType.PkgPath(), iType.Name(), t.Name, tag), + } + fields = append(fields, val) + } + } + + vf.Store(iType, fields) + return fields, nil +} + +func splitValidateTag(s string) []string { + var strs []string + var last int + var block int + for i := range s { + switch s[i] { + case '(': + block++ + case ')': + block-- + case ',': + if block == 0 { + strs = append(strs, s[last:i]) + last = i + 1 + } + } + } + strs = append(strs, s[last:]) + for i, str := range strs { + if str != "" && str[0] == '(' && str[len(str)-1] == ')' { + strs[i] = str[1 : len(str)-1] + } + } + strs = sliceFilter(strs, func(s string) bool { return s != "" }) + return strs +} + +type filterRule struct { + FuncCreator + Rules sync.Map +} + +// FilterData defines Filter data matching objects and rules. +// +// Package and Name define the filter structure object name, +// you can use '*' fuzzy matching. +// +// Checks and Modifys define data matching and modification behavior functions. +// If Modifys is not defined, the entire data object will be empty. +// +// FilterData 定义Filter数据匹配对象和规则。 +// +// Package和Name定义过滤结构体对象名称,可以使用'*'模糊匹配。 +// +// Checks和Modifys定义数据匹配和修改行为函数,如果未定义Modifys将整个数据对象置空。 +type FilterData struct { + Package string `alias:"package" json:"package" xml:"package" yaml:"package"` + Name string `alias:"name" json:"name" xml:"name" yaml:"name"` + Checks []string `alias:"checks" json:"checks" xml:"checks" yaml:"checks"` + Modifys []string `alias:"modifys" json:"modifys" xml:"modifys" yaml:"modifys"` +} + +// NewFilterRules function creates FilterData filter function. +// +// Load filter rules from ctx.Value(ContextKeyFilterRules). +// +// NewFilterRules 函数创建FilterData过滤函数。 +// +// 从ctx.Value(ContextKeyFilterRules)加载过滤规则。 +func NewFilterRules(c context.Context) HandlerDataFunc { + filter := &filterRule{FuncCreator: NewFuncCreatorWithContext(c)} + return func(ctx Context, i any) error { + switch rule := ctx.Value(ContextKeyFilterRules).(type) { + case []string: + filter.filte(reflect.ValueOf(i), &FilterData{Checks: rule}) + case *FilterData: + filter.filte(reflect.ValueOf(i), rule) + case []FilterData: + for i := range rule { + filter.filte(reflect.ValueOf(i), &rule[i]) + } + } + return nil + } +} + +func (f *filterRule) filte(v reflect.Value, data *FilterData) { + switch v.Kind() { + case reflect.Ptr, reflect.Interface: + if !v.IsNil() { + f.filte(v.Elem(), data) + } + case reflect.Struct: + if v.CanSet() && data.matchName(v.Type()) { + f.filteData(v, data) + } + case reflect.Slice, reflect.Array: + eType := v.Type().Elem() + for eType.Kind() == reflect.Ptr { + eType = eType.Elem() + } + + if eType.Kind() == reflect.Struct { + if data.matchName(eType) { + for i := 0; i < v.Len(); i++ { + f.filteData(v.Index(i), data) + } + } + } else { + for i := 0; i < v.Len(); i++ { + f.filte(v.Index(i), data) + } + } + } +} + +func (d *FilterData) matchName(iType reflect.Type) bool { + return matchStarWithEmpty(d.Package, iType.PkgPath()) && matchStarWithEmpty(d.Name, iType.Name()) +} + +func (f *filterRule) filteData(v reflect.Value, data *FilterData) { + if f.filteRules(v, data.Checks, 0) { + if len(data.Modifys) == 0 { + v.Set(reflect.Zero(v.Type())) + } else { + f.filteRules(v, data.Modifys, FuncCreateNumber) + } + } +} + +func (f *filterRule) filteRules(v reflect.Value, rules []string, kind FuncCreateKind) bool { + for _, rule := range rules { + key, val, _ := strings.Cut(rule, "=") + field, err := getValue(v, key, []string{"filter", "alias"}, false) + if err != nil { + continue + } + + k := NewFuncCreateKindWithType(field.Type()) + if k == FuncCreateInvalid { + continue + } + k += kind + + r, ok := f.Rules.Load(k.String() + rule) + if !ok { + fn, err := f.FuncCreator.CreateFunc(k, val) + if err == nil { + r = &FuncRunner{Kind: k, Func: fn} + } else { + r = &FuncRunner{} + } + f.Rules.Store(k.String()+rule, r) + } + + b := r.(*FuncRunner).RunPtr(field) + if !b { + return false + } + } + return true +} + +// matchStar 模式匹配对象,允许使用带'*'的模式。 +func matchStarWithEmpty(patten, obj string) bool { + if patten == "" { + return true + } + parts := strings.Split(patten, "*") + if len(parts) < 2 { + return patten == obj + } + if !strings.HasPrefix(obj, parts[0]) { + return false + } + for _, i := range parts { + if i == "" { + continue + } + pos := strings.Index(obj, i) + if pos == -1 { + return false + } + obj = obj[pos+len(i):] + } + return true +} diff --git a/handlerextender.go b/handlerextender.go new file mode 100644 index 0000000..0460c0f --- /dev/null +++ b/handlerextender.go @@ -0,0 +1,734 @@ +package eudore + +import ( + "context" + "fmt" + "net/http" + "reflect" + "runtime" + "strings" +) + +// HandlerExtender defines the method of extending the function handler. +// +// The HandlerExtender object has three default implementations, Base, Warp, and Tree, +// which are defined in the HandlerExtender interface. +// +// HandlerExtender 定义函数扩展处理者的方法。 +// +// HandlerExtender默认拥有Base、Warp、Tree三种实现,具体参数三种对象的文档。 +type HandlerExtender interface { + RegisterExtender(string, any) error + CreateHandler(string, any) HandlerFuncs + List() []string +} + +// HandlerFuncs is a collection of HandlerFunc, +// representing multiple request processing functions. +// +// handlerExtenderBase 定义基础的函数扩展。 +type handlerExtenderBase struct { + NewType []reflect.Type + NewFunc []reflect.Value + AnyType []reflect.Type + AnyFunc []reflect.Value +} + +// handlerExtenderWarp 定义链式函数扩展。 +type handlerExtenderWarp struct { + data HandlerExtender + last HandlerExtender +} + +// handlerExtenderTree 定义基于路径匹配的函数扩展。 +type handlerExtenderTree struct { + data HandlerExtender + path string + childs []*handlerExtenderTree +} + +type MetadataHandlerExtender struct { + Health bool `alias:"health" json:"health" xml:"health" yaml:"health"` + Name string `alias:"name" json:"name" xml:"name" yaml:"name"` + Extender []string `alias:"extender" json:"extender" xml:"extender" yaml:"extender"` +} + +var ( + // contextFuncName key类型一定为HandlerFunc类型,保存函数可能正确的名称。 + contextFuncName = make(map[uintptr]string) // 最终名称 + contextSaveName = make(map[uintptr]string) // 函数名称 + contextAliasName = make(map[uintptr][]string) // 对象名称 + fineLineFieldsKeys = []string{"file", "line"} +) + +// NewHandlerExtender 函数创建默认HandlerExtender并加载默认扩展函数。 +func NewHandlerExtender() HandlerExtender { + he := NewHandlerExtenderBase() + for _, fn := range DefaultHandlerExtenderFuncs { + _ = he.RegisterExtender("", fn) + } + return he +} + +func NewHandlerExtenderWithContext(ctx context.Context) HandlerExtender { + he, ok := ctx.Value(ContextKeyHandlerExtender).(HandlerExtender) + if ok { + return he + } + return DefaultHandlerExtender +} + +// NewHandlerExtenderBase method returns a basic function extension processing object. +// +// The NewHandlerExtenderBase().RegisterExtender method registers +// a conversion function and ignores the path. +// +// The NewHandlerExtenderBase().CreateHandler method implementation +// creates multiple request handler functions, ignoring paths. +// +// NewHandlerExtenderBase 方法返回一个基本的函数扩展处理对象。 +// +// NewHandlerExtenderBase().RegisterExtender 方法实现注册一个转换函数,忽略路径。 +// +// NewHandlerExtenderBase().CreateHandler 方法实现创建多个请求处理函数,忽略路径。 +func NewHandlerExtenderBase() HandlerExtender { + return &handlerExtenderBase{} +} + +// RegisterExtender function registers a request context handling conversion function. +// The parameter must be a function that takes a function, an interface, or a pointer type +// as a parameter and returns a HandlerFunc object. +// +// If multiple interface type conversions are added, the registration type is not directly the interface +// but the implementation interface, and the implementation interface will be checked +// in the order of interface registration. +// +// For example: func(func(...)) HanderFunc, func(http.Handler) HandlerFunc +// +// RegisterExtender 函数注册一个请求上下文处理转换函数,参数必须是一个函数, +// 该函数的参数必须是一个函数、接口、指针类型之一,返回值必须是返回一个HandlerFunc对象。 +// +// 如果添加多个接口类型转换,注册类型不直接是接口而是实现接口,会按照接口注册顺序依次检测是否实现接口。 +// +// 例如: func(func(...)) HanderFunc, func(http.Handler) HandlerFunc. +func (he *handlerExtenderBase) RegisterExtender(_ string, fn any) error { + iType := reflect.TypeOf(fn) + // RegisterExtender函数的参数必须是一个函数类型 + if iType.Kind() != reflect.Func { + return ErrHandlerExtenderParamNotFunc + } + + // 检查函数参数必须为 func(Type) 或 func(string, Type), + // 允许使用的type值定义在DefaultHandlerExtendAllowType。 + if (iType.NumIn() != 1) && (iType.NumIn() != 2 || iType.In(0).Kind() != reflect.String) { + return fmt.Errorf(ErrFormatHandlerExtenderInputParamError, iType.String()) + } + _, ok := DefaultHandlerExtenderAllowType[iType.In(iType.NumIn()-1).Kind()] + if !ok { + return fmt.Errorf(ErrFormatHandlerExtenderInputParamError, iType.String()) + } + + // 检查函数返回值必须是HandlerFunc + if iType.NumOut() != 1 || iType.Out(0) != typeHandlerFunc { + return fmt.Errorf(ErrFormatHandlerExtenderOutputParamError, iType.String()) + } + + he.NewType = append(he.NewType, iType.In(iType.NumIn()-1)) + he.NewFunc = append(he.NewFunc, reflect.ValueOf(fn)) + if iType.In(iType.NumIn()-1).Kind() == reflect.Interface { + he.AnyType = append(he.AnyType, iType.In(iType.NumIn()-1)) + he.AnyFunc = append(he.AnyFunc, reflect.ValueOf(fn)) + } + return nil +} + +// CreateHandler 函数根据参数返回一个HandlerFuncs。 +func (he *handlerExtenderBase) CreateHandler(path string, i any) HandlerFuncs { + val, ok := i.(reflect.Value) + if !ok { + val = reflect.ValueOf(i) + } + return NewHandlerFuncsFilter(he.newHandlerFuncs(path, val)) +} + +func (he *handlerExtenderBase) newHandlerFuncs(path string, v reflect.Value) HandlerFuncs { + // 基础类型返回 + switch fn := v.Interface().(type) { + case func(Context): + SetHandlerFuncName(fn, getHandlerAliasName(v)) + return HandlerFuncs{fn} + case HandlerFunc: + SetHandlerFuncName(fn, getHandlerAliasName(v)) + return HandlerFuncs{fn} + case []HandlerFunc: + return fn + case HandlerFuncs: + return fn + } + // 尝试转换成HandlerFuncs + fn := he.newHandlerFunc(path, v) + if fn != nil { + return HandlerFuncs{fn} + } + // 解引用数组再转换HandlerFuncs + switch v.Type().Kind() { + case reflect.Slice, reflect.Array: + var fns HandlerFuncs + for i := 0; i < v.Len(); i++ { + hs := he.newHandlerFuncs(path, v.Index(i)) + if hs != nil { + fns = append(fns, hs...) + } + } + if len(fns) != 0 { + return fns + } + case reflect.Interface, reflect.Ptr: + return he.newHandlerFuncs(path, v.Elem()) + } + return nil +} + +// The newHandlerFunc function converts a function or interface parameter into +// a request context handler function. +// +// The parameter must be a function, the function has a parameter as an input parameter, +// and a HandlerFunc object as a return value. +// +// First check whether the object has a directly registered type extension function, +// and then check whether the object implements the registered interface type. +// +// Multiple registrations are allowed, as long as the registration return value is not empty, +// the corresponding processing function will be returned. +// +// newHandlerFunc 函数使用一个函数或接口参数转换成请求上下文处理函数。 +// +// 参数必须是一个函数,函数拥有一个参数作为入参,一个HandlerFunc对象作为返回值。 +// +// 先检测对象是否拥有直接注册的类型扩展函数,再检查对象是否实现其中注册的接口类型。 +// +// 允许进行多次注册,只要注册返回值不为空就会返回对应的处理函数。 +func (he *handlerExtenderBase) newHandlerFunc(path string, v reflect.Value) HandlerFunc { + iType := v.Type() + for i := range he.NewType { + if he.NewType[i] == iType { + h := he.createHandlerFunc(path, he.NewFunc[i], v) + if h != nil { + return h + } + } + } + // 判断是否实现接口类型 + for i, iface := range he.AnyType { + if iType.Implements(iface) { + h := he.createHandlerFunc(path, he.AnyFunc[i], v) + if h != nil { + return h + } + } + } + return nil +} + +// The createHandlerFunc function creates a HandlerFunc using the conversion function and object, +// and saves the name of the HandlerFunc and the name of the extended function used. +// createHandlerFunc 函数使用转换函数和对象创建一个HandlerFunc,并保存HandlerFunc的名称和使用的扩展函数名称。 +func (he *handlerExtenderBase) createHandlerFunc(path string, fn, v reflect.Value) (h HandlerFunc) { + if fn.Type().NumIn() == 1 { + h = fn.Call([]reflect.Value{v})[0].Interface().(HandlerFunc) + } else { + h = fn.Call([]reflect.Value{reflect.ValueOf(path), v})[0].Interface().(HandlerFunc) + } + if h == nil { + return nil + } + // 获取新函数名称,一般来源于函数扩展返回的函数名称。 + hptr := getFuncPointer(reflect.ValueOf(h)) + name := contextSaveName[hptr] + // 使用原值名称 + if name == "" && v.Kind() != reflect.Struct { + name = getHandlerAliasName(v) + } + // 推断名称 + if name == "" { + iType := v.Type() + switch iType.Kind() { + case reflect.Func: + name = runtime.FuncForPC(v.Pointer()).Name() + case reflect.Ptr: + iType = iType.Elem() + name = fmt.Sprintf("*%s.%s", iType.PkgPath(), iType.Name()) + case reflect.Struct: + name = fmt.Sprintf("%s.%s", iType.PkgPath(), iType.Name()) + default: + name = "any" + } + } + // 获取扩展名称,eudore包移除包前缀 + extname := strings.TrimPrefix(runtime.FuncForPC(fn.Pointer()).Name(), "github.com/eudore/eudore.") + contextFuncName[hptr] = fmt.Sprintf("%s(%s)", name, extname) + return h +} + +var formarExtendername = "%s(%s)" + +// The List method returns all registered function names. +// +// List 方法返回全部注册的函数名称。 +func (he *handlerExtenderBase) List() []string { + names := make([]string, 0, len(he.NewFunc)) + for i := range he.NewType { + if he.NewType[i].Kind() != reflect.Interface { + name := runtime.FuncForPC(he.NewFunc[i].Pointer()).Name() + names = append(names, fmt.Sprintf(formarExtendername, name, he.NewType[i].String())) + } + } + for i, iface := range he.AnyType { + name := runtime.FuncForPC(he.AnyFunc[i].Pointer()).Name() + names = append(names, fmt.Sprintf(formarExtendername, name, iface.String())) + } + return names +} + +func (he *handlerExtenderBase) Metadata() any { + return MetadataHandlerExtender{ + Health: true, + Name: "eudore.handlerExtenderBase", + Extender: he.List(), + } +} + +// NewHandlerExtenderWarp function creates a chained HandlerExtender object. +// +// All objects are registered and created using base. If base cannot create a function handler, +// use last to create a function handler. +// +// The NewHandlerExtenderWarp(base, last).RegisterExtender method +// uses the base object to register extension functions. +// +// The NewHandlerExtenderWarp(base, last).CreateHandler method first +// uses the base object to create multiple request processing functions. +// If it returns nil, it uses the last object to create multiple request processing functions. +// +// NewHandlerExtenderWarp 函数创建一个链式HandlerExtender对象。 +// +// 所有对象注册和创建均使用base,如果base无法创建函数处理者则使用last创建函数处理者。 +// +// NewHandlerExtenderWarp(base, last).RegisterExtender 方法使用base对象注册扩展函数。 +// +// NewHandlerExtenderWarp(base, last).CreateHandler 方法先使用base对象创建多个请求处理函数, +// 如果返回nil,则使用last对象创建多个请求处理函数。 +func NewHandlerExtenderWarp(base, last HandlerExtender) HandlerExtender { + return &handlerExtenderWarp{ + data: base, + last: last, + } +} + +// RegisterExtender 方法基于路径注册一个扩展函数。 +func (he *handlerExtenderWarp) RegisterExtender(path string, i any) error { + return he.data.RegisterExtender(path, i) +} + +// The CreateHandler method implements the CreateHandler function. +// If the current HandlerExtender cannot create HandlerFuncs, +// it calls the superior HandlerExtender to process. +// +// CreateHandler 方法实现CreateHandler函数,如果当前HandlerExtender无法创建HandlerFuncs, +// 则调用上级HandlerExtender处理。 +func (he *handlerExtenderWarp) CreateHandler(path string, i any) HandlerFuncs { + hs := he.data.CreateHandler(path, i) + if hs != nil { + return hs + } + return he.last.CreateHandler(path, i) +} + +// List 方法返回全部注册的函数名称。 +func (he *handlerExtenderWarp) List() []string { + return append(he.last.List(), he.data.List()...) +} + +func (he *handlerExtenderWarp) Metadata() any { + return MetadataHandlerExtender{ + Health: true, + Name: "eudore.handlerExtenderWarp", + Extender: he.List(), + } +} + +// NewHandlerExtenderTree function creates a path-based function extender. +// +// Mainly implement path matching. All actions are processed by the node's HandlerExtender, +// and the NewHandlerExtenderBase () object is used. +// +// All registration and creation actions will be performed by matching the lowest node of the tree. +// If it cannot be created, the tree nodes will be processed upwards in order. +// +// The NewHandlerExtenderTree().RegisterExtender method registers a handler function based on the path, +// and initializes to NewHandlerExtenderBase () if the HandlerExtender is empty. +// +// The NewHandlerExtenderTree().CreateHandler method matches the child nodes of the tree based on the path, +// and then executes the CreateHandler method from the most child node up. +// If it returns non-null, it returns directly. +// +// NewHandlerExtenderTree 函数创建一个基于路径的函数扩展者。 +// +// 主要实现路径匹配,所有行为使用节点的HandlerExtender处理,使用NewHandlerExtenderBase()对象。 +// +// 所有注册和创建行为都会匹配树最下级节点执行,如果无法创建则在树节点依次向上处理。 +// +// NewHandlerExtenderTree().RegisterExtender 方法基于路径注册一个处理函数, +// 如果HandlerExtender为空则初始化为NewHandlerExtenderBase()。 +// +// NewHandlerExtenderTree().CreateHandler 方法基于路径向树子节点匹配, +// 后从最子节点依次向上执行CreateHandler方法,如果返回非空直接返回,否在会依次执行注册行为。 +func NewHandlerExtenderTree() HandlerExtender { + return &handlerExtenderTree{} +} + +// RegisterExtender 方法基于路径注册一个扩展函数。 +func (he *handlerExtenderTree) RegisterExtender(path string, i any) error { + // 匹配当前节点注册 + if path == "" { + if he.data == nil { + he.data = NewHandlerExtenderBase() + } + return he.data.RegisterExtender("", i) + } + + // 寻找对应的子节点注册 + for pos := range he.childs { + subStr, find := getSubsetPrefix(path, he.childs[pos].path) + if find { + if subStr != he.childs[pos].path { + he.childs[pos].path = strings.TrimPrefix(he.childs[pos].path, subStr) + he.childs[pos] = &handlerExtenderTree{ + path: subStr, + childs: []*handlerExtenderTree{he.childs[pos]}, + } + } + return he.childs[pos].RegisterExtender(strings.TrimPrefix(path, subStr), i) + } + } + + // 追加一个新的子节点 + newnode := &handlerExtenderTree{ + path: path, + data: NewHandlerExtenderBase(), + } + he.childs = append(he.childs, newnode) + return newnode.data.RegisterExtender(path, i) +} + +// CreateHandler 函数基于路径创建多个对象处理函数。 +// +// 递归依次寻找子节点,然后返回时创建多个对象处理函数,如果子节点返回不为空就直接返回。 +func (he *handlerExtenderTree) CreateHandler(path string, data any) HandlerFuncs { + for _, child := range he.childs { + if strings.HasPrefix(path, child.path) { + hs := child.CreateHandler(path[len(child.path):], data) + if hs != nil { + return hs + } + break + } + } + + if he.data != nil { + return he.data.CreateHandler(path, data) + } + return nil +} + +// The listExtendHandlerNamesByPrefix method recursively adds path prefixes +// and returns extension function names. +// +// listExtendHandlerNamesByPrefix 方法递归添加路径前缀返回扩展函数名称。 +func (he *handlerExtenderTree) listExtendHandlerNamesByPrefix(prefix string) []string { + prefix += he.path + var names []string + if he.data != nil { + names = he.data.List() + if prefix != "" { + for i := range names { + names[i] = prefix + " " + names[i] + } + } + } + + for i := range he.childs { + names = append(names, he.childs[i].listExtendHandlerNamesByPrefix(prefix)...) + } + return names +} + +// List 方法返回全部注册的函数名称。 +func (he *handlerExtenderTree) List() []string { + return he.listExtendHandlerNamesByPrefix("") +} + +func (he *handlerExtenderTree) Metadata() any { + return MetadataHandlerExtender{ + Health: true, + Name: "eudore.handlerExtenderTree", + Extender: he.List(), + } +} + +func getFileLineFieldsVals(v reflect.Value) []any { + file, line := runtime.FuncForPC(v.Pointer()).FileLine(1) + return []any{file, line} +} + +// NewHandlerFunc 函数处理func()。 +func NewHandlerFunc(fn func()) HandlerFunc { + return func(Context) { + fn() + } +} + +// NewHandlerFuncContextError 函数处理func(Context) error返回的error处理。 +func NewHandlerFuncContextError(fn func(Context) error) HandlerFunc { + fineLineFieldsVals := getFileLineFieldsVals(reflect.ValueOf(fn)) + return func(ctx Context) { + err := fn(ctx) + if err != nil { + ctx.WithFields(fineLineFieldsKeys, fineLineFieldsVals).Fatal(err) + } + } +} + +// NewHandlerFuncContextAnyError 函数处理func(Context) (T, error)返回数据渲染和error处理。 +func NewHandlerFuncContextAnyError(fn any) HandlerFunc { + v := reflect.ValueOf(fn) + iType := v.Type() + if iType.Kind() != reflect.Func || iType.NumIn() != 1 || iType.NumOut() != 2 || + iType.In(0) != typeContext || iType.Out(1) != typeError { + return nil + } + + fineLineFieldsVals := getFileLineFieldsVals(reflect.ValueOf(fn)) + return func(ctx Context) { + vals := v.Call([]reflect.Value{reflect.ValueOf(ctx)}) + err, _ := vals[1].Interface().(error) + if err == nil && ctx.Response().Size() == 0 { + err = ctx.Render(vals[0].Interface()) + } + if err != nil { + ctx.WithFields(fineLineFieldsKeys, fineLineFieldsVals).Fatal(err) + } + } +} + +// NewHandlerFuncContextRender 函数处理func(Context) any返回数据渲染。 +func NewHandlerFuncContextRender(fn func(Context) any) HandlerFunc { + fineLineFieldsVals := getFileLineFieldsVals(reflect.ValueOf(fn)) + return func(ctx Context) { + data := fn(ctx) + if ctx.Response().Size() == 0 { + err := ctx.Render(data) + if err != nil { + ctx.WithFields(fineLineFieldsKeys, fineLineFieldsVals).Fatal(err) + } + } + } +} + +// NewHandlerFuncContextRenderError 函数处理func(Context) (any, error)返回数据渲染和error处理。 +func NewHandlerFuncContextRenderError(fn func(Context) (any, error)) HandlerFunc { + fineLineFieldsVals := getFileLineFieldsVals(reflect.ValueOf(fn)) + return func(ctx Context) { + data, err := fn(ctx) + if err == nil && ctx.Response().Size() == 0 { + err = ctx.Render(data) + } + if err != nil { + ctx.WithFields(fineLineFieldsKeys, fineLineFieldsVals).Fatal(err) + } + } +} + +// NewHandlerFuncError 函数处理func() error返回的error处理。 +func NewHandlerFuncError(fn func() error) HandlerFunc { + fineLineFieldsVals := getFileLineFieldsVals(reflect.ValueOf(fn)) + return func(ctx Context) { + err := fn() + if err != nil { + ctx.WithFields(fineLineFieldsKeys, fineLineFieldsVals).Fatal(err) + } + } +} + +// NewHandlerFuncRPC function needs to pass in a function that returns a request for +// processing and is dynamically called by reflection. +// +// Function form: func (Context, Request) (Response, error) +// +// The types of Request and Response can be map or struct or pointer to struct. +// All 4 parameters need to exist, and the order cannot be changed. +// +// NewHandlerFuncRPC 函数需要传入一个函数,返回一个请求处理,通过反射来动态调用。 +// +// 函数形式: func(Context, Request) (Response, error) +// +// Request和Response的类型可以为map或结构体或者结构体的指针,4个参数需要全部存在,且不可调换顺序。 +func NewHandlerFuncRPC(fn any) HandlerFunc { + iType := reflect.TypeOf(fn) + v := reflect.ValueOf(fn) + if iType.Kind() != reflect.Func { + return nil + } + if iType.NumIn() != 2 || iType.In(0) != typeContext { + return nil + } + if iType.NumOut() != 2 || iType.Out(1) != typeError { + return nil + } + typeIn := iType.In(1) + kindIn := typeIn.Kind() + typenew := iType.In(1) + // 检查请求类型 + switch typeIn.Kind() { + case reflect.Ptr, reflect.Map, reflect.Slice, reflect.Struct: + default: + return nil + } + if typenew.Kind() == reflect.Ptr { + typenew = typenew.Elem() + } + + fineLineFieldsVals := getFileLineFieldsVals(v) + return func(ctx Context) { + // 创建请求参数并初始化 + req := reflect.New(typenew) + err := ctx.Bind(req.Interface()) + if err != nil { + ctx.Fatal(err) + return + } + if kindIn != reflect.Ptr { + req = req.Elem() + } + + // 反射调用执行函数。 + vals := v.Call([]reflect.Value{reflect.ValueOf(ctx), req}) + + // 检查函数执行err。 + err, ok := vals[1].Interface().(error) + if ok { + ctx.WithFields(fineLineFieldsKeys, fineLineFieldsVals).Fatal(err) + return + } + + // 渲染返回的数据。 + err = ctx.Render(vals[0].Interface()) + if err != nil { + ctx.Fatal(err) + } + } +} + +// NewHandlerFuncRPCMap defines a fixed request and response to function processing of +// type map [string] interface {}. +// +// is a subset of NewRPCHandlerFunc and has type restrictions, +// but using map [string] interface {} to save requests does not use reflection. +// +// NewHandlerFuncRPCMap 定义了固定请求和响应为map[string]any类型的函数处理。 +// +// 是NewRPCHandlerFunc的一种子集,拥有类型限制,但是使用map[string]any保存请求没用使用反射。 +func NewHandlerFuncRPCMap(fn func(Context, map[string]any) (any, error)) HandlerFunc { + fineLineFieldsVals := getFileLineFieldsVals(reflect.ValueOf(fn)) + return func(ctx Context) { + req := make(map[string]any) + err := ctx.Bind(&req) + if err != nil { + ctx.Fatal(err) + return + } + resp, err := fn(ctx, req) + if err == nil && ctx.Response().Size() == 0 { + err = ctx.Render(resp) + } + if err != nil { + ctx.WithFields(fineLineFieldsKeys, fineLineFieldsVals).Fatal(err) + } + } +} + +// NewHandlerFuncRender 函数处理func() any。 +func NewHandlerFuncRender(fn func() any) HandlerFunc { + fineLineFieldsVals := getFileLineFieldsVals(reflect.ValueOf(fn)) + return func(ctx Context) { + data := fn() + if ctx.Response().Size() == 0 { + err := ctx.Render(data) + if err != nil { + ctx.WithFields(fineLineFieldsKeys, fineLineFieldsVals).Fatal(err) + } + } + } +} + +// NewHandlerFuncRenderError 函数处理func() (any, error)返回数据渲染和error处理。 +func NewHandlerFuncRenderError(fn func() (any, error)) HandlerFunc { + fineLineFieldsVals := getFileLineFieldsVals(reflect.ValueOf(fn)) + return func(ctx Context) { + data, err := fn() + if err == nil && ctx.Response().Size() == 0 { + err = ctx.Render(data) + } + if err != nil { + ctx.WithFields(fineLineFieldsKeys, fineLineFieldsVals).Fatal(err) + } + } +} + +// NewHandlerFuncString 函数处理func() string,然后指定函数生成的字符串。 +func NewHandlerFuncString(fn func() string) HandlerFunc { + return func(ctx Context) { + ctx.WriteString(fn()) + } +} + +type handlerHTTP interface { + HandleHTTP(Context) +} + +// NewHandlerHTTP 函数handlerHTTP接口转换成HandlerFunc。 +func NewHandlerHTTP(h handlerHTTP) HandlerFunc { + return h.HandleHTTP +} + +// NewHandlerHTTPFunc1 函数转换处理http.HandlerFunc类型。 +func NewHandlerHTTPFunc1(fn http.HandlerFunc) HandlerFunc { + return func(ctx Context) { + fn(ctx.Response(), ctx.Request()) + } +} + +// NewHandlerHTTPFunc2 函数转换处理func(http.ResponseWriter, *http.Request)类型。 +func NewHandlerHTTPFunc2(fn func(http.ResponseWriter, *http.Request)) HandlerFunc { + return func(ctx Context) { + fn(ctx.Response(), ctx.Request()) + } +} + +// NewHandlerNetHTTP 函数转换处理http.Handler对象。 +func NewHandlerHTTPHandler(h http.Handler) HandlerFunc { + clone, ok := h.(interface{ CloneHandler() http.Handler }) + if ok { + h = clone.CloneHandler() + } + return func(ctx Context) { + h.ServeHTTP(ctx.Response(), ctx.Request()) + } +} + +// NewHandlerStringer 函数处理fmt.Stringer接口类型转换成HandlerFunc。 +func NewHandlerStringer(fn fmt.Stringer) HandlerFunc { + return func(ctx Context) { + ctx.WriteString(fn.String()) + } +} diff --git a/logger.go b/logger.go index d90ff26..509e4a8 100644 --- a/logger.go +++ b/logger.go @@ -1,643 +1,506 @@ package eudore import ( - "bufio" "context" - "encoding" - "encoding/json" "fmt" - "io" "os" - "path/filepath" - "reflect" "runtime" + "sort" "strconv" "strings" "sync" "time" - "unicode/utf8" ) -// 定义日志级别 +// 枚举使用的日志级别。 const ( LoggerDebug LoggerLevel = iota LoggerInfo LoggerWarning LoggerError LoggerFatal + LoggerDiscard ) /* +Logger defines a log output interface to implement the following functions: + + Five-level log format output + Log entries with Fields attribute + json/text ordered formatted output + Custom processing Hook + Expression filter log + Initialize log processing + Standard output stream displays colored Level + Set file line information output + Log file soft link + log file rollover + log file cleanup + Logger 定义日志输出接口实现下列功能: + 五级日志格式化输出 日志条目带Fields属性 - json有序格式化输出 - 日志器初始化前日志处理 - 文件行信息输出 - 默认输入文件切割并软连接。 + json/text有序格式化输出 + 自定义处理Hook + 表达式过滤日志 + 初始化日志处理 + 标准输出流显示彩色Level + 设置文件行信息输出 + 日志文件软连接 + 日志文件滚动 + 日志文件清理 */ type Logger interface { - Debug(...interface{}) - Info(...interface{}) - Warning(...interface{}) - Error(...interface{}) - Fatal(...interface{}) - Debugf(string, ...interface{}) - Infof(string, ...interface{}) - Warningf(string, ...interface{}) - Errorf(string, ...interface{}) - Fatalf(string, ...interface{}) - WithField(string, interface{}) Logger - WithFields([]string, []interface{}) Logger + Debug(...any) + Info(...any) + Warning(...any) + Error(...any) + Fatal(...any) + Debugf(string, ...any) + Infof(string, ...any) + Warningf(string, ...any) + Errorf(string, ...any) + Fatalf(string, ...any) + WithField(string, any) Logger + WithFields([]string, []any) Logger GetLevel() LoggerLevel SetLevel(LoggerLevel) - Sync() error } -// LoggerLevel 定义日志级别 +// LoggerLevel 定义日志级别。 type LoggerLevel int32 -// loggerInitHandler 定义初始日志处理器必要接口,使用新日志处理器处理当前记录的全部日志。 -type loggerInitHandler interface { - NextHandler(Logger) +// loggerStd 定义日志默认实现条目信息。 +type loggerStd struct { + LoggerEntry + Handlers []LoggerHandler + Pool *sync.Pool + Logger bool + Depth int32 } -// LoggerStdConfig 定义loggerStd配置信息。 -// -// Writer 设置日志输出流,如果为空会使用Std和Path创建一个LoggerWriter。 -// -// Std 是否输出日志到os.Stdout标准输出流。 -// -// Path 指定文件输出路径,如果为空强制指定Std为true。 -// -// MaxSize 指定文件切割大小,需要Path中存在index字符串,用于替换成切割文件索引。 -// -// Link 如果非空会作为软连接的目标路径。 -// -// Level 日志输出级别。 -// -// TimeFormat 日志输出时间格式化格式。 -// -// FileLine 是否输出调用日志输出的函数和文件位置 -type LoggerStdConfig struct { - Writer LoggerWriter `json:"-" xml:"-" alias:"writer" description:"Logger output writer."` - Std bool `json:"std" xml:"std" alias:"std" description:"Is output to os.Stdout."` - Path string `json:"path" xml:"path" alias:"path" description:"Output logger file path."` - MaxSize uint64 `json:"maxsize" xml:"maxsize" alias:"maxsize" description:"Output file max size, 'Path' must contain 'index'."` - Link string `json:"link" xml:"link" alias:"link" description:"Output file link to path."` - Level LoggerLevel `json:"level" xml:"level" alias:"level" description:"Logger Output level."` - TimeFormat string `json:"timeformat" xml:"timeformat" alias:"timeformat" description:"Logger output timeFormat, default '2006-01-02 15:04:05'"` - FileLine bool `json:"fileline" xml:"fileline" alias:"fileline" description:"Is output file and line."` -} - -// LoggerStd 定义日志默认实现条目信息。 -type LoggerStd struct { - LoggerStdData - // enrty data - Time time.Time - Message string - Keys []string - Vals []interface{} - Buffer []byte - Timeformat string - // 日志标识 true是Logger false是Entry - Logger bool - Level LoggerLevel - Depth int -} - -// LoggerStdData 定义loggerStd的数据存储 -type LoggerStdData interface { - GetLogger() *LoggerStd - PutLogger(*LoggerStd) - Sync() error +// LoggerEntry 定义日志条目数据对象。 +type LoggerEntry struct { + Level LoggerLevel + Time time.Time + Message string + Keys []string + Vals []any + Buffer []byte } -// NewLoggerStd 创建一个标准日志处理器。 +// LoggerHandler 定义LoggerEntry处理方法 // -// 参数为一个eudore.LoggerStdConfig或map保存的创建配置,配置选项含义参考eudore.LoggerStdConfig说明。 -func NewLoggerStd(arg interface{}) Logger { - // 解析配置 - data, ok := arg.(LoggerStdData) - if !ok { - data = NewLoggerStdDataJSON(arg) - } - log := data.GetLogger() - log.Logger = true - return log -} - -// NewLoggerWithContext 方法从环境上下文ContextKeyLogger获取Logger,如果无法获取Logger返回DefaultLoggerNull对象。 -func NewLoggerWithContext(ctx context.Context) Logger { - log, ok := ctx.Value(ContextKeyLogger).(Logger) - if ok { - return log - } - return DefaultLoggerNull -} - -// NewLoggerInit The initial log processor only records logs, and gets a new Logger to process logs when Unmount. +// HandlerPriority 方法返回Handler处理顺序,小值优先。 // -// NewLoggerInit 初始日志处理器仅记录日志,在Unmount时获取新Logger处理日志. -func NewLoggerInit() Logger { - return NewLoggerStd(&loggerStdDataInit{}) +// HandlerEntry 方法处理Entry内容,设置Level=LoggerDiscard后结束后续处理。 +type LoggerHandler interface { + HandlerPriority() int + HandlerEntry(*LoggerEntry) +} + +// LoggerConfig 定义loggerStd配置信息。 +type LoggerConfig struct { + // 设置额外的LoggerHandler,和配置初始化创建的Handlers排序后处理LoggerEntry。 + Handlers []LoggerHandler `alias:"handlers" json:"-" xml:"-" yaml:"-"` + // 设置日志输出级别。 + Level LoggerLevel `alias:"level" json:"level" xml:"level" yaml:"level"` + // 是否记录调用者信息。 + Caller bool `alias:"caller" json:"caller" xml:"caller" yaml:"caller"` + // 设置Entry输出格式,默认值为json, + // 如果为json/text启用NewLoggerFormatterJSON/NewLoggerFormatterText。 + Formatter string `alias:"formater" json:"formater" xml:"formater" yaml:"formater"` + // 设置日志时间输出格式,默认值为DefaultLoggerFormatterFormatTime或time.RFC3339。 + TimeFormat string `alias:"timeformat" json:"timeformat" xml:"timeformat" yaml:"timeformat"` + // 设置Entry过滤规则;如果非空启用NewLoggerHookFilter。 + HookFilter [][]string `alias:"hoolfilter" json:"hoolfilter" xml:"hoolfilter" yaml:"hoolfilter"` + // 是否处理Fatal级别日志,调用应用结束方法;如果为true启用NewLoggerHookMeta。 + HookFatal bool `alias:"hookfatal" json:"hookfatal" xml:"hookfatal" yaml:"hookfatal"` + // 是否采集Meta信息,记录日志count、size;如果为true启用NewLoggerHookFatal。 + HookMeta bool `alias:"hookmeta" json:"hookmeta" xml:"hookmeta" yaml:"hookmeta"` + // 是否输出日志到os.Stdout标准输出流;如果存在Env EnvEudoreDaemonEnable时会强制修改为false; + // 如果为true启动NewLoggerWriterStdout。 + Stdout bool `alias:"stdout" json:"stdout" xml:"stdout" yaml:"stdout"` + // 是否输出日志时使用彩色Level,默认在windows系统下禁用。 + StdColor bool `alias:"stdcolor" json:"stdcolor" xml:"stdcolor" yaml:"stdcolor"` + // 设置日志文件输出路径;如果非空启用NewLoggerWriterFile, + // 如果Path包含关键字yyyy/mm/dd/hh或MaxSize非0则改为启用NewLoggerWriterRotate。 + Path string `alias:"path" json:"path" xml:"path" yaml:"path" description:"Output file path."` + // 设置日志文件滚动size,在文件名后缀之前添加索引值。 + MaxSize uint64 `alias:"maxsize" json:"maxsize" xml:"maxsize" yaml:"maxsize" description:"roatte file max size"` + // 设置日志文件最多保留天数,如果非0使用hookFileRecycle。 + MaxAge int `alias:"maxage" json:"maxage" xml:"maxage" yaml:"maxage"` + // 设置日志文件最多保留数量,如果非0使用hookFileRecycle。 + MaxCount int `alias:"maxcount" json:"maxcount" xml:"maxcount" yaml:"maxcount"` + // 设置日志文件软链接名称,如果非空使用hookFileLink。 + Link string `alias:"link" json:"link" xml:"link" yaml:"link" description:"Output file link to path."` +} + +type MetadataLogger struct { + Health bool `alias:"health" json:"health" xml:"health" yaml:"health"` + Name string `alias:"name" json:"name" xml:"name" yaml:"name"` + Count [5]uint64 `alias:"count" json:"count" xml:"count" yaml:"count"` + Size uint64 `alias:"size" json:"size" xml:"size" yaml:"size"` + SizeFormat string `alias:"sizeformat" json:"sizeformat" xml:"sizeformat" yaml:"sizeformat"` } -// NewLoggerNull 定义空日志输出,丢弃所有日志。 -func NewLoggerNull() Logger { - return NewLoggerStd(&loggerStdDataNull{}) -} +/* +NewLogger 创建一个标准日志处理器。 -// NewLoggerStdDataJSON 函数创建一个LoggerStd的JSON数据处理器。 -func NewLoggerStdDataJSON(arg interface{}) LoggerStdData { - config := &LoggerStdConfig{ - TimeFormat: DefaultLoggerTimeFormat, - } - ConvertTo(arg, config) - logdepath := 3 - if config.FileLine { - logdepath |= 0x100 - } - if config.Writer == nil { - var err error - config.Path = strings.TrimSpace(config.Path) - config.Writer, err = NewLoggerWriterRotate(config.Path, config.Std, config.MaxSize, newLoggerLinkName(config.Link)) - if err != nil { - panic(err) - } - } +默认配置: - data := &loggerStdDataJSON{ - LoggerWriter: config.Writer, + &LoggerConfig{ + Stdout: true, + StdColor: DefaultLoggerEnableStdColor, + HookFatal: DefaultLoggerEnableHookFatal, + HookMeta: DefaultLoggerEnableHookMeta, } - data.Pool.New = func() interface{} { - return &LoggerStd{ - LoggerStdData: data, - Level: config.Level, - Buffer: make([]byte, 0, 2048), - Keys: make([]string, 0, 4), - Vals: make([]interface{}, 0, 4), - Timeformat: config.TimeFormat, - Depth: logdepath, +*/ +func NewLogger(config *LoggerConfig) Logger { + if config == nil { + config = &LoggerConfig{ + Stdout: true, + StdColor: DefaultLoggerEnableStdColor, + HookFatal: DefaultLoggerEnableHookFatal, + HookMeta: DefaultLoggerEnableHookMeta, } } - return data -} -type loggerStdDataJSON struct { - LoggerWriter - sync.Mutex - sync.Pool - done chan struct{} -} - -// Mount 方法启动周期Sync,每80ms执行一次。 -func (data *loggerStdDataJSON) Mount(ctx context.Context) { - data.Lock() - defer data.Unlock() - if data.done == nil { - data.done = make(chan struct{}) - } - go func() { - ticker := time.NewTicker(DefaultLoggerSyncDuration) - for { - select { - case <-data.done: - ticker.Stop() - close(data.done) - data.done = nil - return - case <-ticker.C: - data.Sync() - } + handlers := config.getHandlers() + pool := &sync.Pool{} + pool.New = func() any { + return &loggerStd{ + Pool: pool, + Handlers: handlers, + LoggerEntry: LoggerEntry{ + Level: config.Level, + Keys: make([]string, 0, DefaultLoggerEntryFieldsLength), + Vals: make([]any, 0, DefaultLoggerEntryFieldsLength), + Buffer: make([]byte, 0, DefaultLoggerEntryBufferLength), + }, } - }() -} - -// Unmount 方法关闭周期Sync。 -func (data *loggerStdDataJSON) Unmount(ctx context.Context) { - data.Lock() - defer data.Unlock() - if data.done != nil { - data.done <- struct{}{} } - data.LoggerWriter.Sync() -} -func (data *loggerStdDataJSON) Sync() error { - data.Lock() - defer data.Unlock() - return data.LoggerWriter.Sync() -} - -func (data *loggerStdDataJSON) GetLogger() *LoggerStd { - return data.Get().(*LoggerStd) -} - -func (data *loggerStdDataJSON) PutLogger(entry *LoggerStd) { - if len(entry.Message) > 0 || len(entry.Keys) > 0 { - switch entry.Depth >> 8 { - case 1: - name, file, line := logFormatNameFileLine(entry.Depth & 0xff) - entry.Keys = append(entry.Keys, "name", "file", "line") - entry.Vals = append(entry.Vals, name, file, line) - case 2, 3: - entry.Keys = append(entry.Keys, "stack") - entry.Vals = append(entry.Vals, GetPanicStack(entry.Depth&0xff+1)) - } - if len(entry.Keys) > len(entry.Vals) { - entry.Keys = entry.Keys[0:len(entry.Vals)] - entry.WithField("loggererr", "LoggerStd.loggerStdDataJSON: The number of field keys and values are not equal") - } - loggerEntryStdFormat(entry) - data.Lock() - data.Write(entry.Buffer) - data.Unlock() - entry.Message = "" - entry.Keys = entry.Keys[0:0] - entry.Vals = entry.Vals[0:0] - entry.Buffer = entry.Buffer[0:0] + log := pool.New().(*loggerStd) + log.Logger = true + log.Depth = 4 + if config.Caller { + log.Depth |= 0x100 } - data.Put(entry) + return log } -type loggerStdDataInit struct { - sync.Mutex - Data []*LoggerStd -} +func (config *LoggerConfig) getHandlers() []LoggerHandler { + config.TimeFormat = GetAnyByString(config.TimeFormat, DefaultLoggerFormatterFormatTime, time.RFC3339) + config.Formatter = GetAnyByString(config.Formatter, DefaultLoggerFormatter, "json") + config.Stdout = config.Stdout && !GetAnyByString[bool](os.Getenv(EnvEudoreDaemonEnable)) + config.StdColor = config.StdColor && (runtime.GOOS != "windows" || DefaultLoggerWriterStdoutWindowsColor) + config.Path = strings.TrimSpace(config.Path) -func (data *loggerStdDataInit) GetLogger() *LoggerStd { - return &LoggerStd{ - LoggerStdData: data, + hs := config.Handlers + // formatter + switch strings.ToLower(config.Formatter) { + case "json": + hs = append(hs, NewLoggerFormatterJSON(config.TimeFormat)) + case "text": + hs = append(hs, NewLoggerFormatterText(config.TimeFormat)) } -} -func (data *loggerStdDataInit) PutLogger(entry *LoggerStd) { - entry.Time = time.Now() - data.Lock() - data.Data = append(data.Data, entry) - data.Unlock() -} -// Unmount 方法获取ContextKeyLogger.(Logger)接受Init存储的日志。 -func (data *loggerStdDataInit) Unmount(ctx context.Context) { - data.Lock() - defer data.Unlock() - logger, _ := ctx.Value(ContextKeyLogger).(Logger) - if logger == nil { - logger = NewLoggerStd(nil) + // hook + if len(config.HookFilter) > 0 { + hs = append(hs, NewLoggerHookFilter(config.HookFilter)) + } + if config.HookMeta { + hs = append(hs, NewLoggerHookMeta()) + } + if config.HookFatal { + hs = append(hs, NewLoggerHookFatal(nil)) } - logger = logger.WithField("depth", "disable").WithField("logger", true) - for _, data := range data.Data { - entry := logger.WithField("time", data.Time) - for i := range data.Keys { - entry = entry.WithField(data.Keys[i], data.Vals[i]) + // writer-stdout + if config.Stdout { + hs = append(hs, NewLoggerWriterStdout(config.StdColor)) + } + // writer-rotate + if config.Path != "" { + var hook []func(string, string) + if config.Link != "" { + hook = append(hook, hookFileLink(config.Link)) } - switch data.Level { - case LoggerDebug: - entry.Debug(data.Message) - case LoggerInfo: - entry.Info(data.Message) - case LoggerWarning: - entry.Warning(data.Message) - case LoggerError: - entry.Error(data.Message) - case LoggerFatal: - entry.Fatal(data.Message) + if config.MaxAge > 0 || config.MaxCount > 1 { + hook = append(hook, hookFileRecycle(config.MaxAge, config.MaxCount)) } + h, err := NewLoggerWriterRotate(config.Path, config.MaxSize, hook...) + if err != nil { + panic(err) + } + hs = append(hs, h) } - data.Data = data.Data[0:0] - logger.Sync() -} -func (data *loggerStdDataInit) Sync() error { - return nil -} - -type loggerStdDataNull struct{} - -func (data *loggerStdDataNull) GetLogger() *LoggerStd { - return &LoggerStd{ - LoggerStdData: data, - } + sort.Slice(hs, func(i, j int) bool { + return hs[i].HandlerPriority() < hs[j].HandlerPriority() + }) + return hs } -func (data *loggerStdDataNull) PutLogger(entry *LoggerStd) { +// NewLoggerInit The initial log processor only records logs, and gets a new Logger to process logs when Unmount. +// +// If the subsequent Logger is not set after LoggerInit is set, App.Run() must be called to release the log in LoggerInit. +// +// NewLoggerInit 初始日志处理器仅记录日志,在Unmount时获取新Logger处理日志. +// +// 在设置LoggerInit后未设置后续Logger,必须调用App.Run()将LoggerInit内日志释放出来。 +func NewLoggerInit() Logger { + return NewLogger(&LoggerConfig{ + Handlers: []LoggerHandler{&loggerHandlerInit{}}, + Formatter: "disable", + HookMeta: true, + }) } -func (data *loggerStdDataNull) Sync() error { - return nil +// NewLoggerNull 定义空日志输出,丢弃所有日志。 +func NewLoggerNull() Logger { + return NewLogger(&LoggerConfig{ + Level: LoggerDiscard, + Formatter: "disable", + }) } -func (entry *LoggerStd) getEntry() *LoggerStd { - newentry := entry.LoggerStdData.GetLogger() - newentry.Level = entry.Level - newentry.Depth = entry.Depth - if len(entry.Keys) != 0 { - newentry.Keys = append(newentry.Keys, entry.Keys...) - newentry.Vals = append(newentry.Vals, entry.Vals...) +// NewLoggerWithContext 方法从环境上下文ContextKeyLogger获取Logger,如果无法获取Logger返回DefaultLoggerNull对象。 +func NewLoggerWithContext(ctx context.Context) Logger { + log, ok := ctx.Value(ContextKeyLogger).(Logger) + if ok { + return log } - return newentry + return DefaultLoggerNull } // Mount 方法使LoggerStd挂载上下文,上下文传递给LoggerStdData。 -func (entry *LoggerStd) Mount(ctx context.Context) { - withMount(ctx, entry.LoggerStdData) +func (log *loggerStd) Mount(ctx context.Context) { + for i := range log.Handlers { + anyMount(ctx, log.Handlers[i]) + } } // Unmount 方法使LoggerStd卸载上下文,上下文传递给LoggerStdData。 -func (entry *LoggerStd) Unmount(ctx context.Context) { - withUnmount(ctx, entry.LoggerStdData) +func (log *loggerStd) Unmount(ctx context.Context) { + for i := len(log.Handlers) - 1; i > -1; i-- { + anyUnmount(ctx, log.Handlers[i]) + } } -// Metadata 方法从LoggerStdData获取元数据返回。 -func (entry *LoggerStd) Metadata() interface{} { - return withMetadata(entry.LoggerStdData) +// Metadata 方法从Handlers查找到第一个anyMetadata对象返回meta。 +func (log *loggerStd) Metadata() any { + for i := range log.Handlers { + meta := anyMetadata(log.Handlers[i]) + if meta != nil { + return meta + } + } + return nil } // GetLevel 方法获取当前日志输出级别,判断级别取消日志生成。 -func (entry *LoggerStd) GetLevel() LoggerLevel { - return entry.Level +func (log *loggerStd) GetLevel() LoggerLevel { + return log.Level } // SetLevel 方法设置当前日志输出级别。 -func (entry *LoggerStd) SetLevel(level LoggerLevel) { - entry.Level = level -} - -// Sync 方法将缓冲写入到输出流。 -func (entry *LoggerStd) Sync() error { - return entry.LoggerStdData.Sync() +func (log *loggerStd) SetLevel(level LoggerLevel) { + log.Level = level } // Debug 方法条目输出Debug级别日志。 -func (entry *LoggerStd) Debug(args ...interface{}) { - if entry.Logger { - entry = entry.getEntry() - } - if entry.Level < 1 { - entry.Level = 0 - entry.Message = fmt.Sprintln(args...) - entry.Message = entry.Message[:len(entry.Message)-1] - } else { - entry.Keys = entry.Keys[0:0] - entry.Vals = entry.Vals[0:0] - } - entry.LoggerStdData.PutLogger(entry) +func (log *loggerStd) Debug(args ...any) { + log.format(LoggerDebug, args...) } // Info 方法条目输出Info级别日志。 -func (entry *LoggerStd) Info(args ...interface{}) { - if entry.Logger { - entry = entry.getEntry() - } - if entry.Level < 2 { - entry.Level = 1 - entry.Message = fmt.Sprintln(args...) - entry.Message = entry.Message[:len(entry.Message)-1] - } else { - entry.Keys = entry.Keys[0:0] - entry.Vals = entry.Vals[0:0] - } - entry.LoggerStdData.PutLogger(entry) +func (log *loggerStd) Info(args ...any) { + log.format(LoggerInfo, args...) } // Warning 方法条目输出Warning级别日志。 -func (entry *LoggerStd) Warning(args ...interface{}) { - if entry.Logger { - entry = entry.getEntry() - } - if entry.Level < 3 { - entry.Level = 2 - entry.Message = fmt.Sprintln(args...) - entry.Message = entry.Message[:len(entry.Message)-1] - } else { - entry.Keys = entry.Keys[0:0] - entry.Vals = entry.Vals[0:0] - } - entry.LoggerStdData.PutLogger(entry) +func (log *loggerStd) Warning(args ...any) { + log.format(LoggerWarning, args...) } // Error 方法条目输出Error级别日志。 -func (entry *LoggerStd) Error(args ...interface{}) { - if entry.Logger { - entry = entry.getEntry() - } - if entry.Level < 4 { - entry.Level = 3 - entry.Message = fmt.Sprintln(args...) - entry.Message = entry.Message[:len(entry.Message)-1] - } else { - entry.Keys = entry.Keys[0:0] - entry.Vals = entry.Vals[0:0] - } - entry.LoggerStdData.PutLogger(entry) +func (log *loggerStd) Error(args ...any) { + log.format(LoggerError, args...) } // Fatal 方法条目输出Fatal级别日志。 -func (entry *LoggerStd) Fatal(args ...interface{}) { - if entry.Logger { - entry = entry.getEntry() - } - entry.Level = 4 - entry.Message = fmt.Sprintln(args...) - entry.Message = entry.Message[:len(entry.Message)-1] - entry.LoggerStdData.PutLogger(entry) +func (log *loggerStd) Fatal(args ...any) { + log.format(LoggerFatal, args...) } -// Debugf 方法格式化写入流Debug级别日志 -func (entry *LoggerStd) Debugf(format string, args ...interface{}) { - if entry.Logger { - entry = entry.getEntry() - } - if entry.Level < LoggerInfo { - entry.Level = LoggerDebug - entry.Message = fmt.Sprintf(format, args...) - } else { - entry.Keys = entry.Keys[0:0] - entry.Vals = entry.Vals[0:0] - } - entry.LoggerStdData.PutLogger(entry) +// Debugf 方法格式化写入流Debug级别日志。 +func (log *loggerStd) Debugf(format string, args ...any) { + log.formatf(LoggerDebug, format, args...) } -// Infof 方法格式写入流出Info级别日志 -func (entry *LoggerStd) Infof(format string, args ...interface{}) { - if entry.Logger { - entry = entry.getEntry() - } - if entry.Level < LoggerWarning { - entry.Level = LoggerInfo - entry.Message = fmt.Sprintf(format, args...) - } else { - entry.Keys = entry.Keys[0:0] - entry.Vals = entry.Vals[0:0] - } - entry.LoggerStdData.PutLogger(entry) +// Infof 方法格式写入流出Info级别日志。 +func (log *loggerStd) Infof(format string, args ...any) { + log.formatf(LoggerInfo, format, args...) } -// Warningf 方法格式化输出写入流Warning级别日志 -func (entry *LoggerStd) Warningf(format string, args ...interface{}) { - if entry.Logger { - entry = entry.getEntry() - } - if entry.Level < LoggerError { - entry.Level = LoggerWarning - entry.Message = fmt.Sprintf(format, args...) - } else { - entry.Keys = entry.Keys[0:0] - entry.Vals = entry.Vals[0:0] - } - entry.LoggerStdData.PutLogger(entry) +// Warningf 方法格式化输出写入流Warning级别日志。 +func (log *loggerStd) Warningf(format string, args ...any) { + log.formatf(LoggerWarning, format, args...) } -// Errorf 方法格式化写入流Error级别日志 -func (entry *LoggerStd) Errorf(format string, args ...interface{}) { - if entry.Logger { - entry = entry.getEntry() - } - if entry.Level < LoggerFatal { - entry.Level = LoggerError - entry.Message = fmt.Sprintf(format, args...) - } else { - entry.Keys = entry.Keys[0:0] - entry.Vals = entry.Vals[0:0] - } - entry.LoggerStdData.PutLogger(entry) +// Errorf 方法格式化写入流Error级别日志。 +func (log *loggerStd) Errorf(format string, args ...any) { + log.formatf(LoggerError, format, args...) } -// Fatalf 方法格式化写入流Fatal级别日志 -func (entry *LoggerStd) Fatalf(format string, args ...interface{}) { - if entry.Logger { - entry = entry.getEntry() - } - entry.Level = 4 - entry.Message = fmt.Sprintf(format, args...) - entry.LoggerStdData.PutLogger(entry) +// Fatalf 方法格式化写入流Fatal级别日志。 +func (log *loggerStd) Fatalf(format string, args ...any) { + log.formatf(LoggerFatal, format, args...) } -// WithFields 方法一次设置多个条目属性。 -// -// 如果key和val同时为nil会返回Logger的深拷贝对象。 -// -// WithFields不会设置Field属性。 -func (entry *LoggerStd) WithFields(key []string, value []interface{}) Logger { - if entry.Logger { - entry = entry.getEntry() +// WithFields 方法一次设置多个属性,但是不会设置Field属性。 +func (log *loggerStd) WithFields(key []string, value []any) Logger { + if log.Logger { + log = log.getLogger() } - entry.Keys = append(entry.Keys, key...) - entry.Vals = append(entry.Vals, value...) - return entry + log.Keys = append(log.Keys, key...) + log.Vals = append(log.Vals, value...) + return log } -// WithField 方法设置一个日志属性。 -// -// 如果key为"context"值类型为context.Context,设置该值用于传递自定义信息。 +// WithField 方法设置一个日志属性,指定key时会执行特色行为。 // -// 如果key为"depth"值类型为int,设置日志调用堆栈增删层数。 +// 如果key为"logger"值为bool(true),将LoggerEntry设置为Logger。 // -// 如果key为"depth"值类型为string值"enable"或"disable",启用或关闭日志调用位置输出。 +// 如果key为"depth"值类型为int,设置日志调用堆栈增删层数; +// 如果key为"depth"值类型为string值"enable"或"disable",启用或关闭日志调用位置输出; +// 并增加key: file/func/stack,如果使用到相关key需要先禁用depth。 // // 如果key为"time"值类型为time.time,设置日志输出的时间属性。 -func (entry *LoggerStd) WithField(key string, value interface{}) Logger { - if entry.Logger { - entry = entry.getEntry() +func (log *loggerStd) WithField(key string, value any) Logger { + if log.Logger { + log = log.getLogger() } switch key { - case "context": - val, ok := value.(context.Context) - if ok { - for i := range entry.Keys { - if entry.Keys[i] == "context" { - entry.Vals[i] = val - return entry - } - } + case "logger": + val, ok := value.(bool) + if ok && val { + log.Logger = true + return log } - case "depth": - return entry.withFieldDepth(key, value) + case ParamDepth: + return log.withFieldDepth(key, value) case "time": val, ok := value.(time.Time) if ok { - entry.Time = val - return entry - } - case "logger": - val, ok := value.(bool) - if ok && val { - entry.Logger = true - return entry + log.Time = val + return log } } - entry.Keys = append(entry.Keys, key) - entry.Vals = append(entry.Vals, value) - return entry + log.Keys = append(log.Keys, key) + log.Vals = append(log.Vals, value) + return log } -// withFieldDepth 方法处理withDepth属性,cost 67 可内联。 -func (entry *LoggerStd) withFieldDepth(key string, value interface{}) Logger { - val, ok := value.(int) - if ok { - entry.Depth += val - return entry - } - vals, ok := value.(string) - if ok { - switch vals { - case "stack": - entry.Depth |= 0x200 +// withFieldDepth 方法处理withDepth属性,cost 53 可内联。 +func (log *loggerStd) withFieldDepth(key string, value any) Logger { + switch val := value.(type) { + case int: + log.Depth += int32(val) + case string: + switch val { case "enable": - entry.Depth |= 0x100 + log.Depth |= 0x100 + case "stack": + log.Depth |= 0x200 case "disable": - entry.Depth &^= 0x300 + log.Depth &^= 0x300 } - return entry + default: + log.Keys = append(log.Keys, key) + log.Vals = append(log.Vals, value) + } + return log +} + +func (log *loggerStd) getLogger() *loggerStd { + entry := log.Pool.Get().(*loggerStd) + entry.Time = time.Now() + entry.Message = "" + entry.Keys = entry.Keys[0:0] + entry.Vals = entry.Vals[0:0] + entry.Buffer = entry.Buffer[0:0] + entry.Level = log.Level + entry.Depth = log.Depth + if len(log.Keys) > 0 { + entry.Keys = append(entry.Keys, log.Keys...) + entry.Vals = append(entry.Vals, log.Vals...) } - entry.Keys = append(entry.Keys, key) - entry.Vals = append(entry.Vals, value) return entry } -var ( - loggerlevels = [][]byte{[]byte("DEBUG"), []byte("INFO"), []byte("WARIRNG"), []byte("ERROR"), []byte("FATAL")} - loggerpart1 = []byte(`{"time":"`) - loggerpart2 = []byte(`","level":"`) - loggerpart3 = []byte(`,"message":"`) - loggerpart4 = []byte("\"}\n") - loggerpart5 = []byte("}\n") - _hex = "0123456789abcdef" -) +func (log *loggerStd) format(level LoggerLevel, args ...any) { + if log.Level <= level { + if log.Logger { + log = log.getLogger() + } + log.Level = level + log.Message = fmt.Sprintln(args...) + log.Message = log.Message[:len(log.Message)-1] + log.handler() + log.Pool.Put(log) + } +} -func loggerEntryStdFormat(entry *LoggerStd) { - t := entry.Time - if t.IsZero() { - t = time.Now() +func (log *loggerStd) formatf(level LoggerLevel, format string, args ...any) { + if log.Level <= level { + if log.Logger { + log = log.getLogger() + } + log.Level = level + log.Message = fmt.Sprintf(format, args...) + log.handler() + log.Pool.Put(log) } - entry.Buffer = append(entry.Buffer, loggerpart1...) - entry.Buffer = t.AppendFormat(entry.Buffer, entry.Timeformat) - entry.Buffer = append(entry.Buffer, loggerpart2...) - entry.Buffer = append(entry.Buffer, loggerlevels[entry.Level]...) - entry.Buffer = append(entry.Buffer, '"') +} - for i := range entry.Keys { - entry.Buffer = append(entry.Buffer, ',') - entry.Buffer = append(entry.Buffer, '"') - entry.Buffer = append(entry.Buffer, entry.Keys[i]...) - entry.Buffer = append(entry.Buffer, '"', ':') - loggerFormatWriteValue(entry, entry.Vals[i]) +func (log *loggerStd) handler() { + if len(log.Keys) > len(log.Vals) { + log.Keys = log.Keys[0:len(log.Vals)] + log.Keys = append(log.Keys, "loggererr") + log.Vals = append(log.Vals, "LoggerStd: The number of field keys and values are not equal") } - if len(entry.Message) > 0 { - entry.Buffer = append(entry.Buffer, loggerpart3...) - loggerFormatWriteString(entry, entry.Message) - entry.Buffer = append(entry.Buffer, loggerpart4...) - } else { - entry.Buffer = append(entry.Buffer, loggerpart5...) + if len(log.Message) > 0 || len(log.Keys) > 0 { + switch log.Depth >> 8 { + case 1: + fname, file := GetCallerFuncFile(int(log.Depth) & 0xff) + if fname != "" { + log.Keys = append(log.Keys, "func") + log.Vals = append(log.Vals, fname) + } + if file != "" { + log.Keys = append(log.Keys, "file") + log.Vals = append(log.Vals, file) + } + case 2, 3: + log.Keys = append(log.Keys, "stack") + log.Vals = append(log.Vals, GetCallerStacks(int(log.Depth&0xff)+1)) + } + for _, h := range log.Handlers { + if log.Level < LoggerDiscard { + h.HandlerEntry(&log.LoggerEntry) + } + } } } // String 方法实现ftm.Stringer接口,格式化输出日志级别。 func (l LoggerLevel) String() string { - return DefaultLoggerLevelString[l] + return DefaultLoggerLevelStrings[l] } // MarshalText 方法实现encoding.TextMarshaler接口,用于编码日志级别。 @@ -648,7 +511,7 @@ func (l LoggerLevel) MarshalText() ([]byte, error) { // UnmarshalText 方法实现encoding.TextUnmarshaler接口,用于解码日志级别。 func (l *LoggerLevel) UnmarshalText(text []byte) error { str := strings.ToUpper(string(text)) - for i, s := range DefaultLoggerLevelString { + for i, s := range DefaultLoggerLevelStrings { if s == str { *l = LoggerLevel(i) return nil @@ -662,435 +525,52 @@ func (l *LoggerLevel) UnmarshalText(text []byte) error { return ErrLoggerLevelUnmarshalText } -// logFormatNameFileLine 函数获得调用的文件位置和函数名称。 -// -// 文件位置会从第一个src后开始截取,处理gopath下文件位置。 -func logFormatNameFileLine(depth int) (string, string, int) { - ptr, file, line, ok := runtime.Caller(depth) - if ok { - slash := strings.Index(file, "src") - if slash >= 0 { - file = file[slash+4:] - } - return runtime.FuncForPC(ptr).Name(), file, line - } - return "", "???", 0 -} - -// GetPanicStack 函数返回panic栈信息。 -func GetPanicStack(depth int) []string { - pc := make([]uintptr, DefaultLoggerDepth) - n := runtime.Callers(depth, pc) - if n == 0 { - return nil - } +var works = [...]string{"/pkg/mod/", "/src/"} - stack := make([]string, 0, n) - frames := runtime.CallersFrames(pc[:n]) - frame, more := frames.Next() - for more { - pos := strings.Index(frame.File, "src") - if pos >= 0 { - frame.File = frame.File[pos+4:] - } - pos = strings.LastIndex(frame.Function, "/") - if pos >= 0 { - frame.Function = frame.Function[pos+1:] +func trimFileName(name string) string { + for _, w := range works { + pos := strings.Index(name, w) + if pos != -1 { + name = name[pos+len(w):] } - stack = append(stack, fmt.Sprintf("%s:%d %s", frame.File, frame.Line, frame.Function)) - - frame, more = frames.Next() } - return stack + return name } -// WriteValue 方法写入值。 -func loggerFormatWriteValue(entry *LoggerStd, value interface{}) { - iValue := reflect.ValueOf(value) - loggerFormatWriteReflect(entry, iValue) -} - -// loggerFormatWriteReflect 方法写入值。 -func loggerFormatWriteReflect(entry *LoggerStd, iValue reflect.Value) { - if loggerFormatWriteReflectFace(entry, iValue) { - return - } - // 写入类型 - switch iValue.Kind() { - case reflect.Bool: - entry.Buffer = strconv.AppendBool(entry.Buffer, iValue.Bool()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - entry.Buffer = strconv.AppendInt(entry.Buffer, iValue.Int(), 10) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - entry.Buffer = strconv.AppendUint(entry.Buffer, iValue.Uint(), 10) - case reflect.Float32, reflect.Float64: - entry.Buffer = strconv.AppendFloat(entry.Buffer, iValue.Float(), 'f', -1, 64) - case reflect.Complex64, reflect.Complex128: - val := iValue.Complex() - r, i := float64(real(val)), float64(imag(val)) - entry.Buffer = append(entry.Buffer, '"') - entry.Buffer = strconv.AppendFloat(entry.Buffer, r, 'f', -1, 64) - entry.Buffer = append(entry.Buffer, '+') - entry.Buffer = strconv.AppendFloat(entry.Buffer, i, 'f', -1, 64) - entry.Buffer = append(entry.Buffer, 'i') - entry.Buffer = append(entry.Buffer, '"') - case reflect.String: - entry.Buffer = append(entry.Buffer, '"') - loggerFormatWriteString(entry, iValue.String()) - entry.Buffer = append(entry.Buffer, '"') - case reflect.Ptr, reflect.Interface: - loggerFormatWriteReflect(entry, iValue.Elem()) - case reflect.Func, reflect.Chan, reflect.UnsafePointer: - entry.Buffer = append(entry.Buffer, '0', 'x') - entry.Buffer = strconv.AppendUint(entry.Buffer, uint64(iValue.Pointer()), 16) - case reflect.Map: - loggerFormatWriteReflectMap(entry, iValue) - case reflect.Array, reflect.Slice: - loggerFormatWriteReflectSlice(entry, iValue) - case reflect.Struct: - loggerFormatWriteReflectStruct(entry, iValue) - } -} - -func loggerFormatWriteReflectFace(entry *LoggerStd, iValue reflect.Value) bool { - switch iValue.Kind() { - case reflect.Map, reflect.Slice: - if iValue.IsNil() { - entry.Buffer = append(entry.Buffer, 'n', 'u', 'l', 'l') - return true - } - case reflect.Ptr, reflect.Func, reflect.Chan: - if iValue.IsNil() { - entry.Buffer = append(entry.Buffer, 'n', 'u', 'l', 'l') - return true - } - case reflect.Interface: - return false - case reflect.Invalid: - entry.Buffer = append(entry.Buffer, '"', '<', 'I', 'n', 'v', 'a', 'l', 'i', 'd', - ' ', 'V', 'a', 'l', 'u', 'e', '>', '"') - return true - } - // 检查接口 - switch val := iValue.Interface().(type) { - case json.Marshaler: - body, err := val.MarshalJSON() - if err == nil { - entry.Buffer = append(entry.Buffer, body...) - } else { - entry.Buffer = append(entry.Buffer, '"') - loggerFormatWriteString(entry, err.Error()) - entry.Buffer = append(entry.Buffer, '"') - } - case encoding.TextMarshaler: - body, err := val.MarshalText() - entry.Buffer = append(entry.Buffer, '"') - if err == nil { - loggerFormatWriteBytes(entry, body) - } else { - loggerFormatWriteString(entry, err.Error()) - } - entry.Buffer = append(entry.Buffer, '"') - case fmt.Stringer: - entry.Buffer = append(entry.Buffer, '"') - loggerFormatWriteString(entry, val.String()) - entry.Buffer = append(entry.Buffer, '"') - case error: - entry.Buffer = append(entry.Buffer, '"') - loggerFormatWriteString(entry, val.Error()) - entry.Buffer = append(entry.Buffer, '"') - default: - return false - } - return true -} - -func loggerFormatWriteReflectStruct(entry *LoggerStd, iValue reflect.Value) { - entry.Buffer = append(entry.Buffer, '{') - pos := len(entry.Buffer) - iType := iValue.Type() - for i := 0; i < iValue.NumField(); i++ { - if iValue.Field(i).CanInterface() { - name, omit := split2byte(iType.Field(i).Tag.Get("json"), ',') - if name == "-" || (omit == "omitempty" && iValue.Field(i).IsZero()) { - continue - } - if name == "" { - name = iType.Field(i).Name - } - entry.Buffer = append(entry.Buffer, '"') - loggerFormatWriteString(entry, name) - entry.Buffer = append(entry.Buffer, '"', ':') - loggerFormatWriteReflect(entry, iValue.Field(i)) - entry.Buffer = append(entry.Buffer, ',') - } - } - if pos == len(entry.Buffer) { - entry.Buffer = append(entry.Buffer, '}') - } else { - entry.Buffer[len(entry.Buffer)-1] = '}' - } -} - -func loggerFormatWriteReflectSlice(entry *LoggerStd, iValue reflect.Value) { - if iValue.Len() == 0 { - entry.Buffer = append(entry.Buffer, '[', ']') - return - } - entry.Buffer = append(entry.Buffer, '[') - for i := 0; i < iValue.Len(); i++ { - loggerFormatWriteReflect(entry, iValue.Index(i)) - entry.Buffer = append(entry.Buffer, ',') - } - entry.Buffer[len(entry.Buffer)-1] = ']' -} - -func loggerFormatWriteReflectMap(entry *LoggerStd, iValue reflect.Value) { - if iValue.Len() == 0 { - entry.Buffer = append(entry.Buffer, '{', '}') - return - } - - entry.Buffer = append(entry.Buffer, '{') - for _, key := range iValue.MapKeys() { - loggerFormatWriteReflect(entry, key) - entry.Buffer = append(entry.Buffer, ':') - loggerFormatWriteReflect(entry, iValue.MapIndex(key)) - entry.Buffer = append(entry.Buffer, ',') +func trimFuncName(name string) string { + pos := strings.LastIndexByte(name, '/') + if pos != -1 { + name = name[pos+1:] } - entry.Buffer[len(entry.Buffer)-1] = '}' + return name } -// loggerFormatWriteString 方法安全写入字符串。 -func loggerFormatWriteString(entry *LoggerStd, s string) { - for i := 0; i < len(s); { - if tryAddRuneSelf(entry, s[i]) { - i++ - continue - } - r, size := utf8.DecodeRuneInString(s[i:]) - if tryAddRuneError(entry, r, size) { - i++ - continue - } - entry.Buffer = append(entry.Buffer, s[i:i+size]...) - i += size - } -} - -// loggerFormatWriteBytes 方法安全写入[]byte的字符串数据。 -func loggerFormatWriteBytes(entry *LoggerStd, s []byte) { - for i := 0; i < len(s); { - if tryAddRuneSelf(entry, s[i]) { - i++ - continue - } - r, size := utf8.DecodeRune(s[i:]) - if tryAddRuneError(entry, r, size) { - i++ - continue - } - entry.Buffer = append(entry.Buffer, s[i:i+size]...) - i += size - } -} - -// tryAddRuneSelf appends b if it is valid UTF-8 character represented in a single byte. -func tryAddRuneSelf(entry *LoggerStd, b byte) bool { - if b >= utf8.RuneSelf { - return false - } - if 0x20 <= b && b != '\\' && b != '"' { - entry.Buffer = append(entry.Buffer, b) - return true - } - switch b { - case '\\', '"': - entry.Buffer = append(entry.Buffer, '\\') - entry.Buffer = append(entry.Buffer, b) - case '\n': - entry.Buffer = append(entry.Buffer, '\\') - entry.Buffer = append(entry.Buffer, 'n') - case '\r': - entry.Buffer = append(entry.Buffer, '\\') - entry.Buffer = append(entry.Buffer, 'r') - case '\t': - entry.Buffer = append(entry.Buffer, '\\') - entry.Buffer = append(entry.Buffer, 't') - default: - // Encode bytes < 0x20, except for the escape sequences above. - entry.Buffer = append(entry.Buffer, `\u00`...) - entry.Buffer = append(entry.Buffer, _hex[b>>4]) - entry.Buffer = append(entry.Buffer, _hex[b&0xF]) - } - return true -} - -func tryAddRuneError(entry *LoggerStd, r rune, size int) bool { - if r == utf8.RuneError && size == 1 { - entry.Buffer = append(entry.Buffer, `\ufffd`...) - return true - } - return false -} - -// LoggerWriter 定义日志写入流,用于写入日志数据。 -type LoggerWriter interface { - Sync() error - io.Writer -} - -type syncWriterFile struct { - *bufio.Writer - file *os.File -} - -type syncWriterRotate struct { - name string - std bool - MaxSize uint64 - nextindex int - nexttime time.Time - nbytes uint64 - *bufio.Writer - file *os.File - newfn []func(string) -} - -// NewLoggerWriterStd 函数返回一个标准输出流的日志写入流。 -func NewLoggerWriterStd() LoggerWriter { - return os.Stdout -} - -// NewLoggerWriterFile 函数创建一个文件输出的日志写入流。 -func NewLoggerWriterFile(name string, std bool) (LoggerWriter, error) { - if name == "" { - return NewLoggerWriterStd(), nil - } - os.MkdirAll(filepath.Dir(name), 0644) - file, err := os.OpenFile(formatDateName(name), os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) - if err != nil { - return nil, err - } - - if std { - return &syncWriterFile{bufio.NewWriter(io.MultiWriter(os.Stdout, file)), file}, nil - } - return &syncWriterFile{bufio.NewWriter(file), file}, nil -} - -// Sync 方法将缓冲数据写入到文件。 -func (w syncWriterFile) Sync() error { - w.Flush() - return w.file.Sync() -} +// GetCallerFuncFile 函数获得调用的文件位置和函数名称。 +// +// 文件位置会从第一个src后开始截取,处理gopath下文件位置。 +func GetCallerFuncFile(depth int) (string, string) { + var pcs [1]uintptr + runtime.Callers(depth+1, pcs[:]) + fs := runtime.CallersFrames(pcs[:]) + f, _ := fs.Next() -// NewLoggerWriterRotate 函数创建一个支持文件切割的的日志写入流。 -func NewLoggerWriterRotate(name string, std bool, maxsize uint64, fn ...func(string)) (LoggerWriter, error) { - if strings.Index(name, "index") == -1 { - maxsize = 0 - } - if maxsize <= 0 { - // 如果同时文件名称不包含日期,那么就具有index和date日志滚动条件。 - if name == formatDateName(name) { - return NewLoggerWriterFile(name, std) - } - maxsize = 0xffffffffff - } - lw := &syncWriterRotate{ - name: name, - std: std, - MaxSize: maxsize, - nexttime: getNextHour(), - newfn: fn, - } - return lw, lw.rotateFile() + return trimFuncName(f.Function), trimFileName(f.File + ":" + strconv.Itoa(f.Line)) } -// Sync 方法将缓冲数据写入到文件。 -func (w *syncWriterRotate) Sync() error { - if w.file == nil { +// GetCallerStacks 函数返回caller栈信息。 +func GetCallerStacks(depth int) []string { + pc := make([]uintptr, DefaultLoggerDepthMaxStack) + n := runtime.Callers(depth, pc) + if n == 0 { return nil } - w.Flush() - return w.file.Sync() -} - -// Write 方法写入日志数据。 -func (w *syncWriterRotate) Write(p []byte) (n int, err error) { - if w.nbytes+uint64(len(p)) >= w.MaxSize { - // 执行size滚动 - w.rotateFile() - } - if time.Now().After(w.nexttime) { - w.nexttime = getNextHour() - // 检查时间变化 - if strings.Replace(formatDateName(w.name), "index", fmt.Sprint(w.nextindex-1), -1) != w.file.Name() { - w.nextindex = 0 - w.rotateFile() - } - } - n, err = w.Writer.Write(p) - if w.std { - os.Stdout.Write(p) - } - w.nbytes += uint64(n) - return -} -func (w *syncWriterRotate) rotateFile() error { - name := formatDateName(w.name) - for { - name := strings.Replace(name, "index", fmt.Sprint(w.nextindex), -1) - os.MkdirAll(filepath.Dir(name), 0644) - file, err := os.OpenFile(name, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) - if err != nil { - return err - } - w.nextindex++ - // 检查open新文件size小于MaxSize - stat, _ := file.Stat() - w.nbytes = uint64(stat.Size()) - if w.nbytes < w.MaxSize { - w.Sync() - w.file.Close() - w.Writer = bufio.NewWriter(file) - w.file = file - for _, fn := range w.newfn { - fn(name) - } - return nil - } - file.Close() - } -} - -func formatDateName(name string) string { - now := time.Now() - name = strings.Replace(name, "yyyy", "2006", 1) - name = strings.Replace(name, "yy", "06", 1) - name = strings.Replace(name, "MM", "01", 1) - name = strings.Replace(name, "dd", "02", 1) - name = strings.Replace(name, "HH", "15", 1) - return now.Format(name) -} - -func getNextHour() time.Time { - now := time.Now() - return time.Date(now.Year(), now.Month(), now.Day(), now.Hour()+1, 0, 0, 0, now.Location()) -} - -func newLoggerLinkName(link string) func(string) { - os.MkdirAll(filepath.Dir(link), 0644) - return func(name string) { - if link == "" { - return - } - if name[0] != '/' { - pwd, _ := os.Getwd() - name = filepath.Join(pwd, name) - } - os.Remove(link) - os.Symlink(name, link) + stack := make([]string, 0, n) + fs := runtime.CallersFrames(pc[:n]) + f, more := fs.Next() + for more { + stack = append(stack, trimFileName(f.File+":"+strconv.Itoa(f.Line))+" "+trimFuncName(f.Function)) + f, more = fs.Next() } + return stack } diff --git a/loggerformatter.go b/loggerformatter.go new file mode 100644 index 0000000..279a549 --- /dev/null +++ b/loggerformatter.go @@ -0,0 +1,627 @@ +package eudore + +/* +在encoding/json基础上 +Func/Chan/UnsafePointer类型输出指针地址 +Ptr/Map/Slice类型循环引用时输出指针地址 +Invalid类型输出null +Map类型不基于Key进行排序 +将fmt.Stringer/errorj接口转换字符串 +*/ + +import ( + "encoding" + "encoding/json" + "fmt" + "reflect" + "strconv" + "sync" + "unicode/utf8" + "unsafe" +) + +var ( + loggerLevelDefaultBytes = [][]byte{ + []byte("DEBUG"), []byte("INFO"), + []byte("WARNING"), []byte("ERROR"), []byte("FATAL"), + } + loggerLevelDefaultLen = []int{5, 4, 7, 5, 5} + loggerLevelColorBytes = [][]byte{ + []byte("\x1b[37mDEBUG\x1b[0m"), + []byte("\x1b[36mINFO\x1b[0m"), []byte("\x1b[33mWARNING\x1b[0m"), + []byte("\x1b[31mERROR\x1b[0m"), []byte("\x1b[31mFATAL\x1b[0m"), + } + loggerendpart1 = []byte("\"}\r\n") + loggerendpart2 = []byte("}\r\n") + _hex = "0123456789abcdef" + storageJSONEncoder sync.Map + storageTextEncoder sync.Map + tableEncodeTypePtr = [...]bool{ + reflect.Chan: true, + reflect.Func: true, + reflect.Interface: true, + reflect.Map: true, + reflect.Ptr: true, + reflect.Slice: true, + reflect.UnsafePointer: true, + } + tableEncodeTypeValue = [...]bool{ + reflect.Bool: true, reflect.Int: true, reflect.Uint: true, reflect.String: true, + reflect.Int8: true, reflect.Int16: true, reflect.Int32: true, reflect.Int64: true, + reflect.Uint8: true, reflect.Uint16: true, reflect.Uint32: true, reflect.Uint64: true, + reflect.Float32: true, reflect.Float64: true, + reflect.Complex64: true, reflect.Complex128: true, + reflect.Chan: true, reflect.Func: true, + reflect.Uintptr: true, reflect.UnsafePointer: true, + } +) + +type loggerFormatterText struct { + TimeFormat string +} + +// NewLoggerFormatterText 函数创建文件行格式日志格式化。 +func NewLoggerFormatterText(timeformat string) LoggerHandler { + return &loggerFormatterText{ + TimeFormat: timeformat + " ", + } +} + +func (h *loggerFormatterText) HandlerPriority() int { return 30 } +func (h *loggerFormatterText) HandlerEntry(entry *LoggerEntry) { + en := &loggerEncoder{ + data: entry.Buffer, + } + en.data = entry.Time.AppendFormat(en.data, h.TimeFormat) + en.data = append(en.data, loggerLevelDefaultBytes[entry.Level]...) + pos := sliceLastIndex(entry.Keys, "file") + if pos != -1 { + if file, ok := entry.Vals[pos].(string); ok { + en.data = append(en.data, ' ') + en.formatString(file) + entry.Keys = entry.Keys[:pos+copy(entry.Keys[pos:], entry.Keys[pos+1:])] + entry.Vals = entry.Vals[:pos+copy(entry.Vals[pos:], entry.Vals[pos+1:])] + } + } + if entry.Message != "" { + en.data = append(en.data, ' ') + en.data = append(en.data, []byte(entry.Message)...) + } + + for i := range entry.Keys { + en.data = append(en.data, ' ') + en.formatString(entry.Keys[i]) + en.data = append(en.data, '=') + en.formatText(reflect.ValueOf(entry.Vals[i])) + } + en.data = append(en.data, '\r', '\n') + entry.Buffer = en.data +} + +type loggerFormatterJSON struct { + TimeFormat string + KeyMessage []byte + KeyTime []byte + KeyLevel []byte +} + +// NewLoggerStdDataJSON 函数创建一个LoggerStd的JSON数据处理器。 +// +// 如果设置EnvEudoreDaemonEnable表示为后台运行,非终端启动会自动设置Std=false; +// 在非windows系统下,仅输出到终端不输出到文件时Level关键字会设置为彩色。 +func NewLoggerFormatterJSON(timeformat string) LoggerHandler { + return &loggerFormatterJSON{ + TimeFormat: timeformat, + KeyTime: []byte(`{"` + DefaultLoggerFormatterKeyTime + `":"`), + KeyLevel: []byte(`","` + DefaultLoggerFormatterKeyLevel + `":"`), + KeyMessage: []byte(`,"` + DefaultLoggerFormatterKeyMessage + `":"`), + } +} + +func (h *loggerFormatterJSON) HandlerPriority() int { return 30 } +func (h *loggerFormatterJSON) HandlerEntry(entry *LoggerEntry) { + en := &loggerEncoder{ + data: entry.Buffer, + } + en.data = append(en.data, h.KeyTime...) + en.data = entry.Time.AppendFormat(en.data, h.TimeFormat) + en.data = append(en.data, h.KeyLevel...) + en.data = append(en.data, loggerLevelDefaultBytes[entry.Level]...) + en.data = append(en.data, '"') + + for i := range entry.Keys { + en.data = append(en.data, ',', '"') + en.data = append(en.data, entry.Keys[i]...) + en.data = append(en.data, '"', ':') + en.formatJSON(reflect.ValueOf(entry.Vals[i])) + } + + if len(entry.Message) > 0 { + en.data = append(en.data, h.KeyMessage...) + en.formatString(entry.Message) + en.data = append(en.data, loggerendpart1...) + } else { + en.data = append(en.data, loggerendpart2...) + } + entry.Buffer = en.data +} + +type loggerEncoder struct { + data []byte + pointers []uintptr +} + +type typeEncoder func(*loggerEncoder, reflect.Value) + +func (en *loggerEncoder) formatText(v reflect.Value) { + if !v.IsValid() { + en.WriteString("null") + return + } + + t := v.Type() + if tableEncodeTypeValue[t.Kind()] && t.NumMethod() == 0 { + valueEncoder(en, v) + return + } else if v.Kind() == reflect.Ptr && !v.IsNil() { + if t.Implements(typeError) { + errorEncoder(en, v) + return + } else if t.Implements(typeFmtStringer) { + fmtStringerEncoder(en, v) + return + } + v = reflect.Indirect(v) + en.WriteBytes('&') + } + + newTextEncoder(v.Type())(en, v) +} + +func (en *loggerEncoder) formatJSON(v reflect.Value) { + if !v.IsValid() { + en.WriteString("null") + return + } + + newJSONEncoder(v.Type())(en, v) +} + +func newTextEncoder(t reflect.Type) typeEncoder { + if tableEncodeTypeValue[t.Kind()] && t.NumMethod() == 0 { + return valueEncoder + } + e, ok := storageTextEncoder.Load(t) + if ok { + return e.(typeEncoder) + } + + e = parseTextEncoder(t) + storageTextEncoder.Store(t, e) + return e.(typeEncoder) +} + +func parseTextEncoder(t reflect.Type) typeEncoder { + if t.Implements(typeError) { + return errorEncoder + } else if t.Implements(typeFmtStringer) { + return fmtStringerEncoder + } + + switch t.Kind() { + case reflect.Struct: + e := &structEncoder{} + storageTextEncoder.Store(t, typeEncoder(e.encodeText)) + e.Fields = parseTextStructFields(t) + return e.encodeText + case reflect.Map: + e := mapEnocder{newTextEncoder(t.Key()), newTextEncoder(t.Elem()), parseTextMapName(t)} + e.Prefix += "{" + return e.encode + case reflect.Slice, reflect.Array: + e := sliceEnocder{newTextEncoder(t.Elem())} + return e.encode + case reflect.Interface: + e := anyEnocder{newTextEncoder} + return e.encode + default: + return valueEncoder + } +} + +func parseTextStructFields(iType reflect.Type) []encodeJSONField { + var fields []encodeJSONField + for i := 0; i < iType.NumField(); i++ { + t := iType.Field(i) + fields = append(fields, encodeJSONField{ + Index: i, + Name: t.Name, + Encoder: newTextEncoder(t.Type), + }) + } + return fields +} + +func parseTextMapName(iType reflect.Type) string { + if iType.Name() == "" { + return "map" + } + return iType.String() +} + +func newJSONEncoder(t reflect.Type) typeEncoder { + if tableEncodeTypeValue[t.Kind()] && t.NumMethod() == 0 { + return valueEncoder + } + e, ok := storageJSONEncoder.Load(t) + if ok { + return e.(typeEncoder) + } + + e = parseJSONEncoder(t) + storageJSONEncoder.Store(t, e) + return e.(typeEncoder) +} + +func parseJSONEncoder(t reflect.Type) typeEncoder { + switch { + case t.Implements(typeJSONMarshaler): + return jsonMarshalerEncoder + case t.Implements(typeTextMarshaler): + return textMarshalerEncoder + case t.Implements(typeError): + return errorEncoder + case t.Implements(typeFmtStringer): + return fmtStringerEncoder + } + + switch t.Kind() { + case reflect.Struct: + e := &structEncoder{} + storageJSONEncoder.Store(t, typeEncoder(e.encode)) + e.Fields = parseJSONStructFields(t) + return e.encode + case reflect.Map: + e := mapEnocder{newJSONEncoder(t.Key()), newJSONEncoder(t.Elem()), ""} + e.Prefix = "{" + return e.encode + case reflect.Slice, reflect.Array: + e := sliceEnocder{newJSONEncoder(t.Elem())} + return e.encode + case reflect.Ptr: + e := ptrEnocder{newJSONEncoder(t.Elem())} + return e.encode + case reflect.Interface: + e := anyEnocder{newJSONEncoder} + return e.encode + default: + return valueEncoder + } +} + +func parseJSONStructFields(iType reflect.Type) []encodeJSONField { + var fields []encodeJSONField + for i := 0; i < iType.NumField(); i++ { + t := iType.Field(i) + name, omit := cutOmit(t.Tag.Get("json")) + if name == "-" || t.Name[0] < 'A' || t.Name[0] > 'Z' { + continue + } + if name == "" { + name = t.Name + } + + field := encodeJSONField{ + Index: i, + Name: name, + Omit: omit, + } + if t.Anonymous && t.Type.Kind() == reflect.Struct { + e := &structEncoder{} + e.Fields = parseJSONStructFields(t.Type) + field.Anonymous = true + field.Encoder = e.encodeFields + } else { + field.Encoder = newJSONEncoder(t.Type) + } + fields = append(fields, field) + } + return fields +} + +type ptrEnocder struct { + Elem typeEncoder +} + +func (e ptrEnocder) encode(en *loggerEncoder, v reflect.Value) { + if en.formatVia(v) { + return + } + defer en.releaseVia() + e.Elem(en, v.Elem()) +} + +type anyEnocder struct { + newEncoder func(reflect.Type) typeEncoder +} + +func (e anyEnocder) encode(en *loggerEncoder, v reflect.Value) { + if v.IsNil() { + en.WriteString("null") + return + } + v = v.Elem() + e.newEncoder(v.Type())(en, v) +} + +type mapEnocder struct { + Key typeEncoder + Val typeEncoder + Prefix string +} + +func (e mapEnocder) encode(en *loggerEncoder, v reflect.Value) { + if en.formatVia(v) { + return + } + defer en.releaseVia() + en.WriteString(e.Prefix) + pos := len(en.data) + i := v.MapRange() + for i.Next() { + e.Key(en, i.Key()) + en.WriteBytes(':') + e.Val(en, i.Value()) + en.WriteBytes(',') + } + if pos == len(en.data) { + en.WriteBytes('}') + } else { + en.data[len(en.data)-1] = '}' + } +} + +type sliceEnocder struct { + Elem typeEncoder +} + +func (e sliceEnocder) encode(en *loggerEncoder, v reflect.Value) { + if en.formatVia(v) { + return + } + defer en.releaseVia() + en.WriteBytes('[') + pos := len(en.data) + for i := 0; i < v.Len(); i++ { + e.Elem(en, v.Index(i)) + en.WriteBytes(',') + } + if pos == len(en.data) { + en.WriteBytes(']') + } else { + en.data[len(en.data)-1] = ']' + } +} + +type structEncoder struct { + Fields []encodeJSONField +} +type encodeJSONField struct { + Index int + Name string + Omit bool + Anonymous bool + Encoder typeEncoder +} + +func (e *structEncoder) encodeText(en *loggerEncoder, v reflect.Value) { + en.WriteBytes('{') + pos := len(en.data) + for _, f := range e.Fields { + en.WriteString(f.Name) + en.WriteBytes(':') + f.Encoder(en, v.Field(f.Index)) + en.WriteBytes(' ') + } + if pos == len(en.data) { + en.WriteBytes('}') + } else { + en.data[len(en.data)-1] = '}' + } +} + +func (e *structEncoder) encode(en *loggerEncoder, v reflect.Value) { + en.WriteBytes('{') + pos := len(en.data) + e.encodeFields(en, v) + if pos == len(en.data) { + en.WriteBytes('}') + } else { + en.data[len(en.data)-1] = '}' + } +} + +func (e *structEncoder) encodeFields(en *loggerEncoder, v reflect.Value) { + for _, f := range e.Fields { + v := v.Field(f.Index) + if f.Anonymous { + f.Encoder(en, v) + continue + } + + if f.Omit && v.IsZero() { + continue + } + en.WriteBytes('"') + en.WriteString(f.Name) + en.WriteBytes('"', ':') + f.Encoder(en, v) + en.WriteBytes(',') + } +} + +func valueEncoder(en *loggerEncoder, v reflect.Value) { + // 写入类型 + switch v.Kind() { + case reflect.Bool: + en.data = strconv.AppendBool(en.data, v.Bool()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + en.data = strconv.AppendInt(en.data, v.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, + reflect.Uint32, reflect.Uint64, reflect.Uintptr: + en.data = strconv.AppendUint(en.data, v.Uint(), 10) + case reflect.Float32, reflect.Float64: + en.data = strconv.AppendFloat(en.data, v.Float(), 'f', -1, 64) + case reflect.Complex64, reflect.Complex128: + val := v.Complex() + en.WriteBytes('"') + en.data = strconv.AppendFloat(en.data, real(val), 'f', -1, 64) + en.WriteBytes('+') + en.data = strconv.AppendFloat(en.data, imag(val), 'f', -1, 64) + en.WriteBytes('i', '"') + case reflect.String: + en.WriteBytes('"') + en.formatString(v.String()) + en.WriteBytes('"') + case reflect.Ptr, reflect.Func, reflect.Chan, reflect.UnsafePointer: + if v.IsNil() { + en.WriteString("null") + return + } + en.WriteBytes('"', '0', 'x') + en.data = strconv.AppendUint(en.data, uint64(v.Pointer()), 16) + en.WriteBytes('"') + } +} + +func errorEncoder(en *loggerEncoder, v reflect.Value) { + if tableEncodeTypePtr[v.Kind()] && v.IsNil() { + en.WriteString("null") + return + } + en.WriteBytes('"') + en.WriteString(v.Interface().(error).Error()) + en.WriteBytes('"') +} + +func fmtStringerEncoder(en *loggerEncoder, v reflect.Value) { + if tableEncodeTypePtr[v.Kind()] && v.IsNil() { + en.WriteString("null") + return + } + en.WriteBytes('"') + en.WriteString(v.Interface().(fmt.Stringer).String()) + en.WriteBytes('"') +} + +func jsonMarshalerEncoder(en *loggerEncoder, v reflect.Value) { + if tableEncodeTypePtr[v.Kind()] && v.IsNil() { + en.WriteString("null") + return + } + body, err := v.Interface().(json.Marshaler).MarshalJSON() + if err == nil { + en.WriteBytes(body...) + } else { + en.WriteBytes('"') + en.formatString(err.Error()) + en.WriteBytes('"') + } +} + +func textMarshalerEncoder(en *loggerEncoder, v reflect.Value) { + if tableEncodeTypePtr[v.Kind()] && v.IsNil() { + en.WriteString("null") + return + } + body, err := v.Interface().(encoding.TextMarshaler).MarshalText() + en.WriteBytes('"') + if err == nil { + en.formatString(*(*string)(unsafe.Pointer(&body))) + } else { + en.formatString(err.Error()) + } + en.WriteBytes('"') +} + +func (en *loggerEncoder) formatVia(v reflect.Value) bool { + if v.IsNil() { + en.WriteString("null") + return true + } + ptr := v.Pointer() + if en.pointers == nil { + en.pointers = make([]uintptr, 0, 4) + } + for _, p := range en.pointers { + if p == ptr { + en.WriteBytes('"', '0', 'x') + en.data = strconv.AppendUint(en.data, uint64(p), 16) + en.WriteBytes('"') + return true + } + } + en.pointers = append(en.pointers, ptr) + return false +} + +func (en *loggerEncoder) releaseVia() { + en.pointers = en.pointers[:len(en.pointers)-1] +} + +// formatString 方法安全写入字符串。 +func (en *loggerEncoder) formatString(s string) { + for i := 0; i < len(s); { + b := s[i] + if b < utf8.RuneSelf { + en.addRuneSelf(b) + i++ + continue + } + r, size := utf8.DecodeRuneInString(s[i:]) + switch r { + case utf8.RuneError: + if size == 1 { + en.WriteString(`\ufffd`) + } + case '\u2028', '\u2029': + en.WriteString(`\u202`) + en.WriteBytes(_hex[r&0xF]) + default: + en.WriteString(s[i : i+size]) + } + i += size + } +} + +func (en *loggerEncoder) addRuneSelf(b byte) { + if 0x20 <= b && b != '\\' && b != '"' { + en.WriteBytes(b) + return + } + switch b { + case '\\', '"': + en.WriteBytes('\\', b) + case '\n': + en.WriteBytes('\\', 'n') + case '\r': + en.WriteBytes('\\', 'r') + case '\t': + en.WriteBytes('\\', 't') + default: + en.WriteString(`\u00`) + en.WriteBytes(_hex[b>>4], _hex[b&0xF]) + } +} + +func (en *loggerEncoder) WriteBytes(b ...byte) { + en.data = append(en.data, b...) +} + +func (en *loggerEncoder) WriteString(s string) { + b := *(*[]byte)(unsafe.Pointer(&struct { + string + Cap int + }{s, len(s)})) + en.data = append(en.data, b...) +} diff --git a/loggerhandler.go b/loggerhandler.go new file mode 100644 index 0000000..e90128a --- /dev/null +++ b/loggerhandler.go @@ -0,0 +1,472 @@ +package eudore + +import ( + "bytes" + "context" + "fmt" + "os" + "path" + "path/filepath" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +type loggerHandlerInit struct { + sync.Mutex + Entrys []LoggerEntry +} + +func (h *loggerHandlerInit) HandlerPriority() int { + return 100 +} + +func (h *loggerHandlerInit) HandlerEntry(entry *LoggerEntry) { + entry.Time = time.Now() + h.Lock() + h.Entrys = append(h.Entrys, *entry) + h.Unlock() +} + +// Unmount 方法获取ContextKeyLogger.(Logger)接受Init存储的日志。 +func (h *loggerHandlerInit) Unmount(ctx context.Context) { + h.Lock() + defer h.Unlock() + logger, _ := ctx.Value(ContextKeyLogger).(Logger) + if logger == nil { + logger = NewLogger(nil) + } + + logger = logger.WithField("depth", "disable").WithField("logger", true) + for _, data := range h.Entrys { + entry := logger.WithField("time", data.Time).WithFields(data.Keys, data.Vals) + switch data.Level { + case LoggerDebug: + entry.Debug(data.Message) + case LoggerInfo: + entry.Info(data.Message) + case LoggerWarning: + entry.Warning(data.Message) + case LoggerError: + entry.Error(data.Message) + case LoggerFatal: + entry.Fatal(data.Message) + } + } + h.Entrys = nil +} + +type loggerHookMeta struct { + Size uint64 + Count [5]uint64 +} + +// NewLoggerHookMeta 函数创建日志Meta处理,记录日志数量和写入量。 +func NewLoggerHookMeta() LoggerHandler { + return &loggerHookMeta{} +} + +func (h *loggerHookMeta) Metadata() any { + return MetadataLogger{ + Health: true, + Name: "eudore.loggerStd", + Count: h.Count, + Size: h.Size, + SizeFormat: formatSize(int64(h.Size)), + } +} +func (h *loggerHookMeta) HandlerPriority() int { return 60 } +func (h *loggerHookMeta) HandlerEntry(entry *LoggerEntry) { + atomic.AddUint64(&h.Size, uint64(len(entry.Buffer))) + atomic.AddUint64(&h.Count[entry.Level], 1) +} + +type loggerHookFilter struct { + Rules [][]string + Funcs [][]loggerHookFilterFunc +} + +type loggerHookFilterFunc struct { + Key string + FuncRunner +} + +// NewLoggerHookFilter 函数创建日志过滤处理器。 +// +// 在Mount时如果规则初始化失败,查看FuncCreator的Metadata。 +func NewLoggerHookFilter(rules [][]string) LoggerHandler { + for i := range rules { + rules[i] = sliceFilter(rules[i], func(t string) bool { + return len(strings.SplitN(t, " ", 3)) == 3 + }) + } + return &loggerHookFilter{ + Rules: rules, + Funcs: make([][]loggerHookFilterFunc, 0, len(rules)), + } +} + +// Mount 方法使LoggerStd挂载上下文,上下文传递给LoggerStdData。 +func (h *loggerHookFilter) Mount(ctx context.Context) { + fc := NewFuncCreatorWithContext(ctx) + for i := range h.Rules { + funcs := make([]loggerHookFilterFunc, 0, len(h.Rules[i])) + for j := range h.Rules[i] { + strs := strings.SplitN(h.Rules[i][j], " ", 3) + kind := NewFuncCreateKind(strs[1]) + fn, err := fc.CreateFunc(kind, strs[2]) + if err != nil { + continue + } + funcs = append(funcs, loggerHookFilterFunc{strs[0], FuncRunner{kind, fn}}) + } + if len(funcs) > 0 { + h.Funcs = append(h.Funcs, funcs) + } + } +} + +func (h *loggerHookFilter) HandlerPriority() int { return 10 } +func (h *loggerHookFilter) HandlerEntry(entry *LoggerEntry) { + for i := range h.Funcs { + h.HandlerRule(entry, h.Funcs[i]) + if entry.Level == LoggerDiscard { + return + } + } +} + +func (h *loggerHookFilter) HandlerRule(entry *LoggerEntry, funcs []loggerHookFilterFunc) { + for i := range funcs { + pos := sliceIndex(entry.Keys, funcs[i].Key) + if pos == -1 { + return + } + + kind := NewFuncCreateKindWithType(reflect.TypeOf(entry.Vals[pos])) + if kind != funcs[i].Kind && kind+FuncCreateNumber != funcs[i].Kind { + return + } + + if funcs[i].Kind > FuncCreateAny { + entry.Vals[pos] = funcRunAny(funcs[i].Kind, funcs[i].Func, entry.Vals[pos]) + } else if !funcs[i].RunPtr(reflect.ValueOf(entry.Vals[pos])) { + return + } + } + if len(entry.Keys) > 0 && funcs[len(funcs)-1].Kind < FuncCreateSetString { + entry.Level = LoggerDiscard + } +} + +func funcRunAny(kind FuncCreateKind, fn, i any) any { + v := reflect.Indirect(reflect.ValueOf(i)) + switch v.Kind() { + case reflect.Slice, reflect.Array: + r := FuncRunner{kind, fn} + for i := 0; i < v.Len(); i++ { + r.Run(v.Index(i)) + } + return i + } + + switch kind { + case FuncCreateSetString: + return fn.(func(string) string)(v.String()) + case FuncCreateSetInt: + return fn.(func(int) int)(int(v.Int())) + case FuncCreateSetUint: + return fn.(func(uint) uint)(uint(v.Uint())) + case FuncCreateSetFloat: + return fn.(func(float64) float64)(v.Float()) + case FuncCreateSetBool: + return fn.(func(bool) bool)(v.Bool()) + default: + return fn.(func(any) any)(i) + } +} + +type loggerHookFatal struct { + Callback func(*LoggerEntry) +} + +// NewLoggerHookFatal 函数创建Fatal级别日志处理Hook。 +func NewLoggerHookFatal(fn func(*LoggerEntry)) LoggerHandler { + return &loggerHookFatal{fn} +} + +func (h *loggerHookFatal) Mount(ctx context.Context) { + if h.Callback == nil { + app, ok := ctx.Value(ContextKeyApp).(*App) + if ok { + h.Callback = func(entry *LoggerEntry) { + app.SetValue(ContextKeyError, fmt.Errorf(entry.Message)) + } + } + } +} + +func (h *loggerHookFatal) HandlerPriority() int { return 101 } +func (h *loggerHookFatal) HandlerEntry(entry *LoggerEntry) { + if entry.Level == LoggerFatal { + if h.Callback == nil { + panic(entry.Message) + } + h.Callback(entry) + } +} + +type loggerWriterStdout struct { + sync.Mutex +} +type loggerWriterStdoutColor struct { + sync.Mutex +} + +func NewLoggerWriterStdout(color bool) LoggerHandler { + if color { + return &loggerWriterStdoutColor{} + } + return &loggerWriterStdout{} +} + +func (h *loggerWriterStdout) HandlerPriority() int { + return 90 +} + +func (h *loggerWriterStdout) HandlerEntry(entry *LoggerEntry) { + h.Lock() + _, _ = os.Stdout.Write(entry.Buffer) + h.Unlock() +} + +func (h *loggerWriterStdoutColor) HandlerPriority() int { + return 90 +} + +func (h *loggerWriterStdoutColor) HandlerEntry(entry *LoggerEntry) { + pos := bytes.Index(entry.Buffer[:64], loggerLevelDefaultBytes[entry.Level]) + h.Lock() + if pos != -1 { + _, _ = os.Stdout.Write(entry.Buffer[:pos]) + _, _ = os.Stdout.Write(loggerLevelColorBytes[entry.Level]) + _, _ = os.Stdout.Write(entry.Buffer[pos+loggerLevelDefaultLen[entry.Level]:]) + } else { + os.Stdout.Write(entry.Buffer) + } + h.Unlock() +} + +type loggerWriterFile struct { + sync.Mutex + File *os.File +} + +// NewLoggerWriterFile 函数创建一个文件输出的日志写入流。 +func NewLoggerWriterFile(name string) (LoggerHandler, error) { + err := os.MkdirAll(filepath.Dir(name), 0o644) + if err != nil { + return nil, err + } + file, err := os.OpenFile(name, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666) + if err != nil { + return nil, err + } + + return &loggerWriterFile{File: file}, nil +} + +func (h *loggerWriterFile) HandlerPriority() int { + return 100 +} + +func (h *loggerWriterFile) HandlerEntry(entry *LoggerEntry) { + h.Lock() + _, _ = h.File.Write(entry.Buffer) + h.Unlock() +} + +type loggerWriterRotate struct { + loggerWriterFile + name string + pattern string + writeSize uint64 + maxSize uint64 + nextIndex int + nextTime time.Time + openhooks []func(string, string) +} + +// max uint64, 9999-12-31 23:59:59 +0000 UTC. +const roatteMaxSize, roatteMaxTime = 0xffffffffffffffff, 253402300799 + +// NewLoggerWriterRotate 函数创建一个支持文件切割的的日志写入流。 +// +// 如果设置maxsize或name包含字符串yyyy/yy/mm/dd/hh,将可以滚动日志文件。 +func NewLoggerWriterRotate(name string, maxsize uint64, fn ...func(string, string)) (LoggerHandler, error) { + if maxsize == 0 && getNextTime(name).Unix() == roatteMaxTime { + return NewLoggerWriterFile(name) + } + if maxsize == 0 { + maxsize = roatteMaxSize + } + h := &loggerWriterRotate{ + name: name, + maxSize: maxsize, + nextIndex: getNextIndex(name, maxsize), + nextTime: getNextTime(name), + openhooks: fn, + } + h.pattern = h.getFilePattern() + return h, h.rotateFile() +} + +func (h *loggerWriterRotate) HandlerEntry(entry *LoggerEntry) { + h.Lock() + defer h.Unlock() + if h.writeSize+uint64(len(entry.Buffer)) >= h.maxSize { + h.rotateFile() + } else if entry.Time.After(h.nextTime) { + h.nextIndex = getNextIndex(h.name, h.maxSize) + h.nextTime = getNextTime(h.name) + h.rotateFile() + } + + n, _ := h.File.Write(entry.Buffer) + h.writeSize += uint64(n) +} + +func (h *loggerWriterRotate) rotateFile() error { + for { + name := h.getRotateName() + _ = os.MkdirAll(filepath.Dir(name), 0o644) + file, err := os.OpenFile(name, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666) + if err != nil { + return err + } + h.nextIndex++ + + stat, _ := file.Stat() + h.writeSize = uint64(stat.Size()) + if h.writeSize < h.maxSize { + h.File.Sync() + h.File.Close() + h.File = file + for _, fn := range h.openhooks { + fn(name, h.pattern) + } + return nil + } + file.Close() + } +} + +func (h *loggerWriterRotate) getRotateName() string { + name := h.name + if h.nextTime.Unix() != roatteMaxTime { + name = fileFormatTime(name) + } + if h.maxSize != roatteMaxSize { + ext := path.Ext(h.name) + name = name[:len(name)-len(ext)] + "-" + strconv.Itoa(h.nextIndex) + ext + } + return name +} + +func (h *loggerWriterRotate) getFilePattern() string { + name := h.name + if h.nextTime.Unix() != roatteMaxTime { + k := DefaultLoggerWriterRotateDataKeys + v := [...]int{2, 2, 2, 4} + for i := range k { + name = strings.ReplaceAll(name, k[i], strings.Repeat("[0-9]", v[i])) + } + } + if h.maxSize != roatteMaxSize { + ext := path.Ext(name) + name = name[:len(name)-len(ext)] + "-*" + ext + } + return name +} + +func fileFormatTime(n string) string { + now := time.Now() + k := DefaultLoggerWriterRotateDataKeys + v := [...]string{"15", "02", "01", "2006"} + for i := range k { + n = strings.ReplaceAll(n, k[i], now.Format(v[i])) + } + return n +} + +func getNextIndex(name string, size uint64) int { + index := 0 + if size != roatteMaxSize { + ext := path.Ext(name) + name := fileFormatTime(name[:len(name)-len(ext)] + "-") + list, _ := filepath.Glob(name + "*" + ext) + for i := range list { + n, _ := strconv.Atoi(list[i][len(name) : len(list[i])-len(ext)]) + if n > index { + index = n + } + } + } + return index +} + +func getNextTime(name string) time.Time { + for i, str := range DefaultLoggerWriterRotateDataKeys { + if strings.Contains(name, str) { + now := time.Now() + datas := [...]int{now.Hour(), now.Day(), int(now.Month()), now.Year()} + datas[i]++ + return time.Date(datas[3], time.Month(datas[2]), datas[1], datas[0], 0, 0, 0, now.Location()) + } + } + return time.Unix(roatteMaxTime, 0) +} + +func hookFileLink(link string) func(string, string) { + os.MkdirAll(filepath.Dir(link), 0o644) + return func(name, _ string) { + if !filepath.IsAbs(name) { + pwd, _ := os.Getwd() + name = filepath.Join(pwd, name) + } + os.Remove(link) + os.Symlink(name, link) + } +} + +func hookFileRecycle(age, count int) func(string, string) { + type fileTime struct { + Name string + ModTime time.Time + } + return func(_, pattern string) { + list, _ := filepath.Glob(pattern) + files := make([]fileTime, 0, len(list)) + for i := range list { + stat, _ := os.Stat(list[i]) + files = append(files, fileTime{list[i], stat.ModTime()}) + } + sort.Slice(files, func(i, j int) bool { + return files[i].ModTime.Before(files[j].ModTime) + }) + + if count < len(files) { + files = files[:len(files)-count] + expr := time.Now().Add(time.Hour * time.Duration(-age)) + for i := range files { + if files[i].ModTime.Before(expr) { + os.Remove(files[i].Name) + } + } + } + } +} diff --git a/middleware/README.md b/middleware/README.md index 78312b9..5e259f7 100644 --- a/middleware/README.md +++ b/middleware/README.md @@ -118,15 +118,16 @@ example: 创建响应压缩中间件,默认提供gzip和deflate压缩 参数: - string 压缩名称 -- func() interface{} 压缩器创建函数 +- func() any 压缩器创建函数 - int 压缩级别 example: ```golang import: "github.com/andybalholm/brotli" -app.AddMiddleware(middleware.NewCompressFunc("br", func() interface{} { return brotli.NewWriter(ioutil.Discard) })) -app.AddMiddleware(middleware.NewCompressGzipFunc(5)) -app.AddMiddleware(middleware.NewCompressDeflateFunc(5)) +app.AddMiddleware(middleware.NewCompressMixinsFunc(nil)) +app.AddMiddleware(middleware.NewCompressFunc("br", func() any { return brotli.NewWriter(ioutil.Discard) })) +app.AddMiddleware(middleware.NewCompressGzipFunc()) +app.AddMiddleware(middleware.NewCompressDeflateFunc()) ``` ## ContextWarp @@ -171,14 +172,14 @@ Cors中间件注册不是全局中间件时,需要最后注册一次Options /\ 校验设置CSRF token 参数: -- interface{} 指明获取csrf token的方法,下列是允许使用的值 +- any 指明获取csrf token的方法,下列是允许使用的值 - "csrf" - "query: csrf" - "header: X-CSRF-Token" - "form: csrf" - func(ctx eudore.Context) string {return ctx.Query("csrf")} - nil -- interface{} 指明设置Cookie的基础信息,下列是允许使用的值 +- any 指明设置Cookie的基础信息,下列是允许使用的值 - "csrf" - http.Cookie{Name: "csrf"} - nil @@ -256,9 +257,9 @@ example: 实现请求令牌桶限流/限速 参数: -- int 每周期(默认秒)增加speed个令牌 -- int 最多拥有的令牌数量 -- ...interface{} 额外使用的Options,根据类型来断言设置选项 +- int 每周期(默认秒)增加speed个令牌 +- int 最多拥有的令牌数量 +- ...any 额外使用的Options,根据类型来断言设置选项 context.Context => 控制cleanupVisitors退出的生命周期 time.Duration => 基础时间周期单位,默认秒 func(eudore.Context) string => 限流获取key的函数,默认Context.ReadIP @@ -340,10 +341,10 @@ app.AddMiddleware("global", middleware.NewRewriteFunc(map[string]string{ 用于执行额外的路由匹配行为 参数: -- map[string]interface{} 请求路径对应的执行函数,路径前缀不指定方法则为Any方法 +- map[string]any 请求路径对应的执行函数,路径前缀不指定方法则为Any方法 example: ``` -app.AddMiddleware(middleware.NewRouterFunc(map[string]interface{}{ +app.AddMiddleware(middleware.NewRouterFunc(map[string]any{ "/api/:v/*": func(ctx eudore.Context) { ctx.Request().URL.Path = "/api/v3/" + ctx.GetParam("*") }, diff --git a/middleware/all.go b/middleware/all.go index f043764..fdcb4de 100644 --- a/middleware/all.go +++ b/middleware/all.go @@ -12,26 +12,14 @@ import ( "github.com/eudore/eudore" ) -type responseMessage struct { - Time string `json:"time"` - Host string `json:"host"` - Method string `json:"method"` - Path string `json:"path"` - Route string `json:"route"` - Status int `json:"status"` - Message string `json:"message,omitempty"` - Error string `json:"error,omitempty"` - Stack []string `json:"stack,omitempty"` - Size int64 `json:"size,omitempty"` - XRequestID string `json:"x-request-id,omitempty"` - XTraceID string `json:"x-trace-id,omitempty"` -} - // HandlerAdmin 函数返回Admin UI界面。 func HandlerAdmin(ctx eudore.Context) { - ctx.SetHeader("X-Eudore-Admin", "ui") + ctx.SetHeader(eudore.HeaderXEudoreAdmin, "ui") ctx.SetHeader(eudore.HeaderContentType, eudore.MimeTextHTMLCharsetUtf8) - http.ServeContent(ctx.Response(), ctx.Request(), "admin.html", now, strings.NewReader(AdminStatic)) + http.ServeContent( + ctx.Response(), ctx.Request(), "admin.html", + now, strings.NewReader(AdminStatic), + ) } // NewBasicAuthFunc 创建一个Basic auth认证中间件。 @@ -63,37 +51,31 @@ func NewBasicAuthFunc(names map[string]string) eudore.HandlerFunc { func NewBodyLimitFunc(size int64) eudore.HandlerFunc { return func(ctx eudore.Context) { req := ctx.Request() - if req.ContentLength > size { + switch { + case req.Body == http.NoBody: + case req.ContentLength > size: + ctx.SetHeader(eudore.HeaderConnection, "close") ctx.WriteHeader(http.StatusRequestEntityTooLarge) - ctx.Render(eudore.NewContextMessgae(ctx, nil, fmt.Sprintf(eudore.ErrFormatMiddlewareRequestEntityTooLargeSzie, req.ContentLength))) + _ = ctx.Render(eudore.NewContextMessgae(ctx, nil, &http.MaxBytesError{Limit: size})) ctx.End() - return + default: + var w http.ResponseWriter = ctx.Response() + for { + unwraper, ok := w.(interface{ Unwrap() http.ResponseWriter }) + if !ok { + break + } + w = unwraper.Unwrap() + } + req.Body = http.MaxBytesReader(w, req.Body, size) } - - req.Body = &limitedReader{req.Body, size} } } -type limitedReader struct { - io.ReadCloser // underlying reader - N int64 // max bytes remaining -} - -func (l *limitedReader) Read(p []byte) (n int, err error) { - if l.N <= 0 { - return 0, eudore.ErrMiddlewareRequestEntityTooLarge - } - if int64(len(p)) > l.N { - p = p[0:l.N] - } - n, err = l.ReadCloser.Read(p) - l.N -= int64(n) - return -} - // NewContextWarpFunc 函数中间件使之后的处理函数使用的eudore.Context对象为新的Context。 // -// 装饰器下可以直接对Context进行包装,而责任链下无法修改Context主体故设计该中间件作为中间件执行机制补充。 +// 装饰器下可以直接对Context进行包装, +// 而责任链下无法修改Context主体故设计该中间件作为中间件执行机制补充。 func NewContextWarpFunc(fn func(eudore.Context) eudore.Context) eudore.HandlerFunc { return func(ctx eudore.Context) { index, handler := ctx.GetHandler() @@ -167,10 +149,17 @@ func NewHeaderWithSecureFunc(h http.Header) eudore.HandlerFunc { // NewHeaderFilteFunc 函数创建请求header过滤中间件,对来源于外部ip请求,过滤指定header。 func NewHeaderFilteFunc(iplist, names []string) eudore.HandlerFunc { if iplist == nil { - iplist = []string{"10.0.0.0/8", "172.16.0.0/12", "192.0.0.0/24", "127.0.0.1", "127.0.0.10"} + iplist = []string{ + "10.0.0.0/8", "172.16.0.0/12", "192.0.0.0/24", + "127.0.0.1", "127.0.0.10", + } } if names == nil { - names = []string{eudore.HeaderXRealIP, eudore.HeaderXForwardedFor, eudore.HeaderXForwardedHost, eudore.HeaderXForwardedProto, eudore.HeaderXRequestID, eudore.HeaderXTraceID} + names = []string{ + eudore.HeaderXRealIP, eudore.HeaderXForwardedFor, + eudore.HeaderXForwardedHost, eudore.HeaderXForwardedProto, + eudore.HeaderXRequestID, eudore.HeaderXTraceID, + } } var list BlackNode for _, ip := range iplist { @@ -194,22 +183,29 @@ func NewHeaderFilteFunc(iplist, names []string) eudore.HandlerFunc { // NewLoggerFunc 函数创建一个请求日志记录中间件。 // -// app参数传入*eudore.App需要使用其Logger输出日志,paramsh获取Context.Params如果不为空则添加到输出日志条目中 +// log参数设置用于输出eudore.Logger, +// params获取Context.Params如果不为空则添加到输出日志条目中 // -// 状态码如果为40x、50x输出日志级别为Error。 +// 状态码如果为50x输出日志级别为Error。 func NewLoggerFunc(log eudore.Logger, params ...string) eudore.HandlerFunc { log = log.WithField("depth", "disable").WithField("logger", true) - keys := []string{"method", "path", "realip", "proto", "host", "status", "request-time", "size"} - headerkeys := [...]string{eudore.HeaderXRequestID, eudore.HeaderXTraceID, eudore.HeaderLocation} + keys := [...]string{ + "method", "path", "realip", "proto", "host", "status", "request-time", "size", + } + headerkeys := [...]string{ + eudore.HeaderXRequestID, + eudore.HeaderXTraceID, + eudore.HeaderLocation, + } headernames := [...]string{"x-request-id", "x-trace-id", "location"} return func(ctx eudore.Context) { now := time.Now() ctx.Next() status := ctx.Response().Status() // 连续WithField保证field顺序 - out := log.WithFields(keys, []interface{}{ + out := log.WithFields(keys[:], []any{ ctx.Method(), ctx.Path(), ctx.RealIP(), ctx.Request().Proto, - ctx.Host(), status, time.Now().Sub(now).String(), ctx.Response().Size(), + ctx.Host(), status, time.Since(now).String(), ctx.Response().Size(), }) for _, param := range params { @@ -249,7 +245,7 @@ func NewLoggerLevelFunc(fn func(ctx eudore.Context) int) eudore.HandlerFunc { fn = func(ctx eudore.Context) int { level := ctx.GetQuery("eudore_debug") if level != "" { - return eudore.GetStringInt(level) + return eudore.GetAnyByString[int](level) } return -1 } @@ -277,25 +273,25 @@ func NewRecoverFunc() eudore.HandlerFunc { if !ok { err = fmt.Errorf("%v", r) } - stack := eudore.GetPanicStack(3) + stack := eudore.GetCallerStacks(3) ctx.WithField("stack", stack).Error(err) if ctx.Response().Size() == 0 { ctx.WriteHeader(eudore.StatusInternalServerError) - ctx.Render(eudore.NewContextMessgae(ctx, err, stack)) + _ = ctx.Render(eudore.NewContextMessgae(ctx, err, stack)) } }() ctx.Next() } } -// NewRequestIDFunc 函数创建一个请求ID注入处理函数,不给定请求ID创建函数,默认使用时间戳和随机数,会将request-id写入协议和附加到日志field。 +// NewRequestIDFunc 函数创建一个请求ID注入处理函数,不给定请求ID创建函数, +// 默认使用时间戳和随机数,会将request-id写入协议和附加到日志field。 func NewRequestIDFunc(fn func(eudore.Context) string) eudore.HandlerFunc { if fn == nil { fn = func(eudore.Context) string { randkey := make([]byte, 3) - io.ReadFull(rand.Reader, randkey) + _, _ = io.ReadFull(rand.Reader, randkey) return fmt.Sprintf("%d-%x", time.Now().UnixNano(), randkey) - } } return func(ctx eudore.Context) { @@ -304,6 +300,9 @@ func NewRequestIDFunc(fn func(eudore.Context) string) eudore.HandlerFunc { requestID = fn(ctx) } ctx.SetHeader(eudore.HeaderXRequestID, requestID) - ctx.SetValue(eudore.ContextKeyLogger, ctx.Value(eudore.ContextKeyLogger).(eudore.Logger).WithField("x-request-id", requestID).WithField("logger", true)) + + log := ctx.Value(eudore.ContextKeyLogger).(eudore.Logger) + log = log.WithField("x-request-id", requestID).WithField("logger", true) + ctx.SetValue(eudore.ContextKeyLogger, log) } } diff --git a/middleware/black.go b/middleware/black.go index 217e4f7..fafab90 100644 --- a/middleware/black.go +++ b/middleware/black.go @@ -8,9 +8,6 @@ import ( "github.com/eudore/eudore" ) -// BlackInvalidAddress 定义解析无效地址时使用的默认地址,127.0.0.2。 -var BlackInvalidAddress uint64 = 2130706434 - // Black 定义黑名单中间件后台。 type black struct { White *BlackNode @@ -51,17 +48,17 @@ func (b *black) InjectRoutes(router eudore.Router) { router.DeleteFunc("/black/black/:ip black=black", b.DeleteAllIP) } -func (b *black) data(ctx eudore.Context) interface{} { - ctx.SetHeader("X-Eudore-Admin", "black") - return map[string]interface{}{ +func (b *black) data(ctx eudore.Context) any { + ctx.SetHeader(eudore.HeaderXEudoreAdmin, "black") + return map[string]any{ "white": b.White.List(nil, 0, 32), "black": b.Black.List(nil, 0, 32), } } func (b *black) putIP(ctx eudore.Context) { - ip := fmt.Sprintf("%s/%s", ctx.GetParam("ip"), eudore.GetString(ctx.GetQuery("mask"), "32")) - ctx.Infof("%s insert %s ip: %s", ctx.RealIP(), eudore.GetString(ctx.GetQuery("black"), "white"), ip) + ip := fmt.Sprintf("%s/%s", ctx.GetParam("ip"), eudore.GetAnyByString(ctx.GetQuery("mask"), "32")) + ctx.Infof("%s insert %s ip: %s", ctx.RealIP(), eudore.GetAnyByString(ctx.GetQuery("black"), "white"), ip) if ctx.GetParam("black") != "" { b.InsertBlack(ip) } else { @@ -70,8 +67,8 @@ func (b *black) putIP(ctx eudore.Context) { } func (b *black) DeleteAllIP(ctx eudore.Context) { - ip := fmt.Sprintf("%s/%s", ctx.GetParam("ip"), eudore.GetString(ctx.GetQuery("mask"), "32")) - ctx.Infof("%s DeleteAll %s ip: %s", ctx.RealIP(), eudore.GetString(ctx.GetQuery("black"), "white"), ip) + ip := fmt.Sprintf("%s/%s", ctx.GetParam("ip"), eudore.GetAnyByString(ctx.GetQuery("mask"), "32")) + ctx.Infof("%s DeleteAll %s ip: %s", ctx.RealIP(), eudore.GetAnyByString(ctx.GetQuery("black"), "white"), ip) if ctx.GetParam("black") != "" { b.DeleteAllBlack(ip) } else { @@ -86,7 +83,7 @@ func (b *black) HandleHTTP(ctx eudore.Context) { return } if b.Black.Look(ip) { - ctx.WriteHeader(403) + ctx.WriteHeader(eudore.StatusForbidden) ctx.WriteString("black list deny your ip " + ctx.RealIP()) ctx.End() } @@ -212,7 +209,7 @@ func ip2intbit(ip string) (uint64, uint) { func ip2int(ip string) uint64 { bits := strings.Split(ip, ".") if len(bits) != 4 { - return BlackInvalidAddress + return DefaultBlackInvalidAddress } b0, _ := strconv.Atoi(bits[0]) b1, _ := strconv.Atoi(bits[1]) diff --git a/middleware/breaker.go b/middleware/breaker.go index 8f0ad09..b7c38dc 100644 --- a/middleware/breaker.go +++ b/middleware/breaker.go @@ -14,7 +14,7 @@ const ( BreakerStatueOpen ) -// BreakerStatues 定义熔断状态字符串 +// BreakerStatues 定义熔断状态字符串。 var BreakerStatues = []string{"closed", "half-open", "open"} // BreakerState 是熔断器状态。 @@ -120,7 +120,7 @@ func (b *Breaker) data(ctx eudore.Context) { } func (b *Breaker) getRoute(ctx eudore.Context) { - id := eudore.GetStringInt(ctx.GetParam("id"), -1) + id := eudore.GetAnyByString(ctx.GetParam("id"), -1) if id < 0 || id > b.Index { ctx.Fatal("id is invalid") return @@ -132,8 +132,8 @@ func (b *Breaker) getRoute(ctx eudore.Context) { } func (b *Breaker) putRouteState(ctx eudore.Context) { - id := eudore.GetStringInt(ctx.GetParam("id"), -1) - state := eudore.GetStringInt(ctx.GetParam("state")) + id := eudore.GetAnyByString(ctx.GetParam("id"), -1) + state := eudore.GetAnyByString[int](ctx.GetParam("state")) if id < 0 || id > b.Index { ctx.Fatal("id is invalid") return @@ -158,14 +158,14 @@ func (c *breakRoute) Handle(ctx eudore.Context) { isdeny := c.BreakerState == BreakerStatueOpen || (c.BreakerState == BreakerStatueHalfOpen && c.OnHalfOpen()) c.Unlock() if isdeny { - ctx.WriteHeader(503) + ctx.WriteHeader(eudore.StatusServiceUnavailable) ctx.Fatal("Breaker deny request: " + c.Name) ctx.End() return } ctx.Next() c.Lock() - if ctx.Response().Status() < 500 { + if ctx.Response().Status() < eudore.StatusInternalServerError { c.TotalSuccesses++ c.ConsecutiveSuccesses++ c.ConsecutiveFailures = 0 @@ -204,7 +204,7 @@ func (c *breakRoute) RetryClose() { } } -// String 方法实现string接口 +// String 方法实现string接口。 func (state BreakerState) String() string { return BreakerStatues[state] } diff --git a/middleware/cache.go b/middleware/cache.go index ad78703..dc700f3 100644 --- a/middleware/cache.go +++ b/middleware/cache.go @@ -3,6 +3,7 @@ package middleware import ( "bytes" "context" + "fmt" "net/http" "sync" "time" @@ -16,28 +17,28 @@ type cache struct { context context.Context getKeyFunc func(eudore.Context) string waits map[string]*sync.WaitGroup - cacheStore + CacheStore } -type cacheStore interface { +type CacheStore interface { Load(string) *CacheData Store(string, *CacheData) } -// NewCacheFunc 函数创建一个缓存中间件,对Get请求具有缓存和SingleFlight双重效果,无法获得中间件之前的响应header数据。 -// -// options: -// -// context.Context => 控制默认cacheMap清理过期数据的生命周期 -// -// time.Duration => 请求数据缓存时间,默认秒 -// -// func(eudore.Context) string => 自定义缓存key,为空则跳过缓存 -// -// cacheStore => 缓存存储对象 -func NewCacheFunc(args ...interface{}) eudore.HandlerFunc { +/* +NewCacheFunc 函数创建一个缓存中间件,对Get请求具有缓存和SingleFlight双重效果, +无法获得中间件之前的响应header数据。 + +options: + + context.Context => 控制默认cacheMap清理过期数据的生命周期。 + time.Duration => 请求数据缓存时间,默认秒。 + func(eudore.Context) string => 自定义缓存key,为空则跳过缓存。 + CacheStore => 缓存存储对象。 +*/ +func NewCacheFunc(args ...any) eudore.HandlerFunc { c := &cache{ - dura: time.Second, + dura: DefaultCacheSaveTime, context: context.Background(), getKeyFunc: func(ctx eudore.Context) string { if ctx.Method() != eudore.MethodGet || ctx.GetHeader(eudore.HeaderUpgrade) != "" { @@ -55,29 +56,34 @@ func NewCacheFunc(args ...interface{}) eudore.HandlerFunc { c.context = val case func(eudore.Context) string: c.getKeyFunc = val - case cacheStore: - c.cacheStore = val + case CacheStore: + c.CacheStore = val } } - if c.cacheStore == nil { - c.cacheStore = newCacheMap(c.context, c.dura) + if c.CacheStore == nil { + c.CacheStore = newCacheMap(c.context, c.dura) } return c.Handle } func (cache *cache) Handle(ctx eudore.Context) { key := cache.getKeyFunc(ctx) - if key == "" { + if key == "" || ctx.GetHeader(eudore.HeaderConnection) == "Upgrade" { return } + fullkey := fmt.Sprintf("%s:%s:%s", key, + ctx.GetHeader(eudore.HeaderAccept), + ctx.GetHeader(eudore.HeaderAcceptEncoding), + ) + var wait *sync.WaitGroup var ok bool for { // load cache - data := cache.Load(key) + data := cache.Load(fullkey) if data != nil { - data.writeData(ctx.Response()) + data.writeData(ctx) ctx.SetParam("cache", key) ctx.End() return @@ -85,11 +91,11 @@ func (cache *cache) Handle(ctx eudore.Context) { // cas cache.Lock() - wait, ok = cache.waits[key] + wait, ok = cache.waits[fullkey] if !ok { wait = new(sync.WaitGroup) wait.Add(1) - cache.waits[key] = wait + cache.waits[fullkey] = wait cache.Unlock() break } @@ -97,23 +103,26 @@ func (cache *cache) Handle(ctx eudore.Context) { wait.Wait() } - w := ctx.Response() - resp := &cacheResponset{ - ResponseWriter: w, - header: make(http.Header), + now := time.Now() + w := &responseWriterCache{ + ResponseWriter: ctx.Response(), + CacheHeader: http.Header{ + eudore.HeaderLastModified: {now.UTC().Format(http.TimeFormat)}, + }, } - ctx.SetResponse(resp) - ctx.Next() ctx.SetResponse(w) - cache.Store(key, &CacheData{ - Expired: time.Now().Add(cache.dura), - Status: w.Status(), - Header: w.Header(), - Body: resp.Buffer.Bytes(), + ctx.Next() + ctx.SetResponse(w.ResponseWriter) + cache.Store(fullkey, &CacheData{ + Expired: now.Add(cache.dura), + ModifiedTime: w.CacheHeader.Get(eudore.HeaderLastModified), + Status: w.Status(), + Header: w.CacheHeader, + Body: w.CacheData.Bytes(), }) cache.Lock() - delete(cache.waits, key) + delete(cache.waits, fullkey) wait.Done() cache.Unlock() } @@ -149,7 +158,7 @@ func (m *cacheMap) Run(ctx context.Context, t time.Duration) { for { select { case now := <-time.After(t): - m.Map.Range(func(key, value interface{}) bool { + m.Map.Range(func(key, value any) bool { item := value.(*CacheData) if now.After(item.Expired) { m.Map.Delete(key) @@ -162,46 +171,67 @@ func (m *cacheMap) Run(ctx context.Context, t time.Duration) { } } -// cacheResponset 对象记录返回的响应数据 +// responseWriterCache 对象记录返回的响应数据 // // Upgrade请求不会进入cache处理,push处理仅push主请求,缓存请求不push无明显影响。 -type cacheResponset struct { +type responseWriterCache struct { eudore.ResponseWriter - header http.Header - bytes.Buffer + CacheData bytes.Buffer + CacheHeader http.Header + ModifiedTime string } // Write 方法实现ResponseWriter中的Write方法。 -func (w *cacheResponset) Write(data []byte) (int, error) { +func (w *responseWriterCache) Write(data []byte) (int, error) { if w.Size() == 0 { h := w.ResponseWriter.Header() - for k, v := range w.header { + h.Add(eudore.HeaderLastModified, w.ModifiedTime) + for k, v := range w.CacheHeader { h[k] = v } } - w.Buffer.Write(data) + w.CacheData.Write(data) return w.ResponseWriter.Write(data) } +func (w *responseWriterCache) WriteString(data string) (int, error) { + if w.Size() == 0 { + h := w.ResponseWriter.Header() + h.Add(eudore.HeaderLastModified, w.ModifiedTime) + for k, v := range w.CacheHeader { + h[k] = v + } + } + w.CacheData.WriteString(data) + return w.ResponseWriter.WriteString(data) +} + // Header 方法返回响应设置的header。 -func (w *cacheResponset) Header() http.Header { - return w.header +func (w *responseWriterCache) Header() http.Header { + return w.CacheHeader } // CacheData 定义缓存的数据类型。 type CacheData struct { - Expired time.Time - Status int - Header http.Header - Body []byte + Expired time.Time + ModifiedTime string + Status int + Header http.Header + Body []byte } // writeData 方法将cache响应数据写入到请求响应。 -func (w *CacheData) writeData(resp eudore.ResponseWriter) { - resp.WriteHeader(w.Status) - h := resp.Header() +func (w *CacheData) writeData(ctx eudore.Context) { + ctx.SetHeader(eudore.HeaderXEudoreCache, "true") + if ctx.GetHeader(eudore.HeaderIfModifiedSince) == w.ModifiedTime { + ctx.SetHeader(eudore.HeaderLastModified, w.ModifiedTime) + ctx.WriteHeader(eudore.StatusNotModified) + return + } + h := ctx.Response().Header() for k, v := range w.Header { h[k] = v } - resp.Write(w.Body) + ctx.WriteHeader(w.Status) + ctx.Write(w.Body) } diff --git a/middleware/compress.go b/middleware/compress.go index 63bb242..96e7cd6 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -4,7 +4,6 @@ import ( "compress/flate" "compress/gzip" "io" - "io/ioutil" "net/http" "strings" "sync" @@ -12,67 +11,118 @@ import ( "github.com/eudore/eudore" ) -// NewCompressFunc 函数创建一个压缩响应处理函数,需要指定压缩算法和压缩对象构造函数。 -// -// import: "github.com/andybalholm/brotli" +// NewCompressGzipFunc 函数创建一个gzip压缩处理函数。 +func NewCompressGzipFunc() eudore.HandlerFunc { + return NewCompressFunc(CompressNameGzip, func() any { + return gzip.NewWriter(io.Discard) + }) +} + +// NewCompressDeflateFunc 函数创建一个deflate压缩处理函数。 +func NewCompressDeflateFunc() eudore.HandlerFunc { + return NewCompressFunc(CompressNameDeflate, func() any { + w, _ := flate.NewWriter(io.Discard, flate.DefaultCompression) + return w + }) +} + +func newResponseWriterCompressPool(name string, fn func() any) *sync.Pool { + return &sync.Pool{ + New: func() any { + return &responseWriterCompress{ + Name: name, + Writer: fn().(compressor), + Buffer: make([]byte, CompressBufferLength), + } + }, + } +} + +// NewCompressFunc 函数创建一个压缩处理函数,需要指定压缩算法和压缩对象构造函数。 +func NewCompressFunc(name string, fn func() any) eudore.HandlerFunc { + pool := newResponseWriterCompressPool(name, fn) + return func(ctx eudore.Context) { + // 检查是否使用压缩 + if !strings.Contains(ctx.GetHeader(eudore.HeaderAcceptEncoding), name) || + ctx.Response().Header().Get(eudore.HeaderContentEncoding) != "" || + strings.Contains(ctx.GetHeader(eudore.HeaderConnection), "Upgrade") || + strings.Contains(ctx.GetHeader(eudore.HeaderContentType), "text/event-stream") { + return + } + + handlerCompress(ctx, pool) + } +} + +// NewCompressMixinsFunc 函数创建一个混合压缩处理函数,默认具有gzip、defalte。 // -// br: middleware.NewCompressFunc("br", func() interface{} { return brotli.NewWriter(ioutil.Discard) }), +// 如果压缩ResponseWriter.Size()值为压缩后size。 // -// gzip: middleware.NewCompressGzipFunc(5) +// 如果设置middleware.DefaultComoressBrotliFunc指定brotli压缩函数,追加br压缩。 // -// deflate: middleware.NewCompressDeflateFunc(5) -func NewCompressFunc(name string, fn func() interface{}) eudore.HandlerFunc { - pool := sync.Pool{New: fn} - return func(ctx eudore.Context) { - // 检查是否使用name压缩 - if shouldNotCompress(ctx, name) { - ctx.Next() - return +// HeaderAcceptEncoding值忽略非零权重值,顺序优先。 +func NewCompressMixinsFunc(compresss map[string]func() any) eudore.HandlerFunc { + if compresss == nil { + compresss = make(map[string]func() any) + compresss[CompressNameGzip] = func() any { + return gzip.NewWriter(io.Discard) + } + compresss[CompressNameDeflate] = func() any { + w, _ := flate.NewWriter(io.Discard, flate.DefaultCompression) + return w } - // 初始化ResponseWriter - w := &responseCompress{ - ResponseWriter: ctx.Response(), - Writer: pool.Get().(compressor), - Name: name, + if DefaultComoressBrotliFunc != nil { + compresss[CompressNameBrotli] = DefaultComoressBrotliFunc } - w.Writer.Reset(ctx.Response()) + } + names := make([]string, 0, len(compresss)) + pools := make([]*sync.Pool, 0, len(compresss)) + for name := range compresss { + names = append(names, name) + pools = append(pools, newResponseWriterCompressPool(name, compresss[name])) + } - ctx.SetResponse(w) - ctx.SetHeader(eudore.HeaderContentEncoding, name) - ctx.SetHeader(eudore.HeaderVary, eudore.HeaderAcceptEncoding) - ctx.Next() + return func(ctx eudore.Context) { + encoding := ctx.GetHeader(eudore.HeaderAcceptEncoding) + if encoding == "" || encoding == CompressNameIdentity || + ctx.Response().Header().Get(eudore.HeaderContentEncoding) != "" || + strings.Contains(ctx.GetHeader(eudore.HeaderConnection), "Upgrade") || + strings.Contains(ctx.GetHeader(eudore.HeaderContentType), "text/event-stream") { + return + } - w.Writer.Close() - pool.Put(w.Writer) + for _, encoding := range strings.Split(encoding, ",") { + name, quality, ok := strings.Cut(strings.TrimSpace(encoding), ";") + if ok && quality == "q=0" { + continue + } + for i := range names { + if names[i] == name { + handlerCompress(ctx, pools[i]) + } + } + } } } -// NewCompressGzipFunc 函数创建一个gzip压缩响应处理函数,如果压缩级别超出gzip范围默认使用5。 -func NewCompressGzipFunc(level int) eudore.HandlerFunc { - if level < gzip.HuffmanOnly || level > gzip.BestCompression { - level = 5 - } - return NewCompressFunc("gzip", func() interface{} { - gz, _ := gzip.NewWriterLevel(ioutil.Discard, level) - return gz - }) -} +func handlerCompress(ctx eudore.Context, pool *sync.Pool) { + // 初始化ResponseWriter + w := pool.Get().(*responseWriterCompress) + w.ResponseWriter = ctx.Response() + w.Buffer = w.Buffer[0:0] + w.State = CompressStateUnknown + defer w.Close(pool) -// NewCompressDeflateFunc 函数创建一个deflate压缩响应处理函数,如果压缩级别超出deflate范围默认使用5。 -func NewCompressDeflateFunc(level int) eudore.HandlerFunc { - if level < flate.HuffmanOnly || level > flate.BestCompression { - level = 5 - } - return NewCompressFunc("deflate", func() interface{} { - gz, _ := flate.NewWriter(ioutil.Discard, level) - return gz - }) + ctx.SetResponse(w) + ctx.Next() } -// responseCompress 定义Gzip响应,实现ResponseWriter接口 -type responseCompress struct { +// responseWriterCompress 定义压缩响应,实现ResponseWriter接口。 +type responseWriterCompress struct { eudore.ResponseWriter Writer compressor + State int + Buffer []byte Name string } @@ -83,58 +133,116 @@ type compressor interface { Close() error } -// Write 实现ResponseWriter中的Write方法。 -func (w responseCompress) Write(data []byte) (int, error) { - return w.Writer.Write(data) +// Unwrap 方法返回原始http.ResponseWrite对象。 +func (w *responseWriterCompress) Unwrap() http.ResponseWriter { + return w.ResponseWriter } -// Flush 实现ResponseWriter中的Flush方法。 -func (w responseCompress) Flush() { - w.Writer.Flush() - w.ResponseWriter.Flush() +// Write 实现ResponseWriter中的Write方法。 +func (w *responseWriterCompress) Write(data []byte) (int, error) { + switch w.State { + case CompressStateEnable: + return w.Writer.Write(data) + case CompressStateDisable: + return w.ResponseWriter.Write(data) + default: + if len(data)+len(w.Buffer) <= CompressBufferLength { + w.State = CompressStateBuffer + w.Buffer = append(w.Buffer, data...) + return len(data), nil + } + + w.init() + return w.Write(data) + } } -func shouldNotCompress(ctx eudore.Context, name string) bool { - h := ctx.Request().Header - if !strings.Contains(h.Get(eudore.HeaderAcceptEncoding), name) || - strings.Contains(h.Get(eudore.HeaderConnection), "Upgrade") || - strings.Contains(h.Get(eudore.HeaderContentType), "text/event-stream") { +func (w *responseWriterCompress) WriteString(data string) (int, error) { + switch w.State { + case CompressStateEnable: + // WriteString only gzip + return io.WriteString(w.Writer, data) + case CompressStateDisable: + return w.ResponseWriter.WriteString(data) + default: + if len(data)+len(w.Buffer) <= CompressBufferLength { + w.State = CompressStateBuffer + w.Buffer = append(w.Buffer, data...) + return len(data), nil + } + + w.init() + return w.WriteString(data) + } +} - return true +func (w *responseWriterCompress) WriteHeader(code int) { + if code == 200 { + return + } + if w.State < CompressStateEnable { + contentlength := w.ResponseWriter.Header().Get(eudore.HeaderContentLength) + if contentlength == "" || len(contentlength) > 3 { + w.init() + } + w.ResponseWriter.WriteHeader(code) } +} - return ctx.Response().Header().Get(eudore.HeaderContentEncoding) != "" +// Flush 实现ResponseWriter中的Flush方法。 +func (w *responseWriterCompress) Flush() { + switch w.State { + case CompressStateEnable: + w.Writer.Flush() + case CompressStateBuffer: + w.init() + w.Writer.Flush() + } + w.ResponseWriter.Flush() } -// Push initiates an HTTP/2 server push. -// Push returns ErrNotSupported if the client has disabled push or if push -// is not supported on the underlying connection. -func (w *responseCompress) Push(target string, opts *http.PushOptions) error { - return w.ResponseWriter.Push(target, w.setAcceptEncodingForPushOptions(opts)) +func (w *responseWriterCompress) Close(pool *sync.Pool) { + switch w.State { + case CompressStateEnable: + w.Writer.Close() + w.Writer.Reset(io.Discard) + case CompressStateBuffer: + w.ResponseWriter.Write(w.Buffer) + } + pool.Put(w) } -// setAcceptEncodingForPushOptions sets "Accept-Encoding" : "gzip" for PushOptions without overriding existing headers. -func (w *responseCompress) setAcceptEncodingForPushOptions(opts *http.PushOptions) *http.PushOptions { - if opts == nil { - opts = &http.PushOptions{ - Header: http.Header{ - eudore.HeaderAcceptEncoding: []string{w.Name}, - }, - } - return opts +func (w *responseWriterCompress) init() { + h := w.ResponseWriter.Header() + contenttype := h.Get(eudore.HeaderContentType) + pos := strings.IndexByte(contenttype, ';') + if pos != -1 { + contenttype = contenttype[:pos] } - if opts.Header == nil { - opts.Header = http.Header{ - eudore.HeaderAcceptEncoding: []string{w.Name}, - } - return opts + if DefaultComoressDisableMime[contenttype] || h.Get(eudore.HeaderContentEncoding) != "" { + w.State = CompressStateDisable + } else { + w.State = CompressStateEnable + w.Writer.Reset(w.ResponseWriter) + h.Del(eudore.HeaderContentLength) + h.Set(eudore.HeaderContentEncoding, w.Name) + h.Set(eudore.HeaderVary, strings.Join(append(h.Values(eudore.HeaderVary), eudore.HeaderAcceptEncoding), ", ")) } + if len(w.Buffer) > 0 { + w.Write(w.Buffer) + } +} - if encoding := opts.Header.Get(eudore.HeaderAcceptEncoding); encoding == "" { +// Push 方法给Push Header设置HeaderAcceptEncoding。 +func (w *responseWriterCompress) Push(target string, opts *http.PushOptions) error { + switch { + case opts == nil: + opts = &http.PushOptions{Header: http.Header{eudore.HeaderAcceptEncoding: {w.Name}}} + case opts.Header == nil: + opts.Header = http.Header{eudore.HeaderAcceptEncoding: {w.Name}} + case opts.Header.Get(eudore.HeaderAcceptEncoding) == "": opts.Header.Add(eudore.HeaderAcceptEncoding, w.Name) - return opts } - - return opts + return w.ResponseWriter.Push(target, opts) } diff --git a/middleware/const.go b/middleware/const.go new file mode 100644 index 0000000..72c20e0 --- /dev/null +++ b/middleware/const.go @@ -0,0 +1,66 @@ +package middleware + +import ( + "errors" + "net/http" + "net/http/pprof" + "time" +) + +const ( + CompressBufferLength = 1024 + CompressStateUnknown = iota + CompressStateBuffer + CompressStateEnable + CompressStateDisable + CompressNameGzip = "gzip" + CompressNameBrotli = "br" + CompressNameDeflate = "deflate" + CompressNameIdentity = "identity" + MimeValueJSON = "value/json" + MimeValueHTML = "value/html" + MimeValueText = "value/text" + QueryFormatJSON = "json" + QueryFormatHTML = "html" + QueryFormatText = "text" +) + +var ( + // DefaultBlackInvalidAddress 定义解析无效地址时使用的默认地址,127.0.0.2。 + DefaultBlackInvalidAddress uint64 = 2130706434 + DefaultCacheSaveTime = time.Second * 10 + // DefaultComoressBrotliFunc 指定brotli压缩构造函数。 + // + // import "github.com/andybalholm/brotli" + // + // middleware.DefaultComoressBrotliFunc = func() any {return brotli.NewWriter(io.Discard)} . + DefaultComoressBrotliFunc func() any + DefaultComoressDisableMime = map[string]bool{ + "application/gzip": true, // gz + "application/zip": true, // zip + "application/x-compressed-tar": true, // tar.gz + "application/x-7z-compressed": true, // 7z + "application/x-rar-compressed": true, // rar + "image/gif": true, // gif + "image/jpeg": true, // jpeg + "image/png": true, // png + "image/svg+xml": true, // svg + "image/webp": true, // webp + "font/woff2": true, // woff2 + } + + DefaultPprofHandlers = map[string]http.Handler{ + "cmdline": http.HandlerFunc(pprof.Cmdline), + "profile": http.HandlerFunc(pprof.Profile), + "symbol": http.HandlerFunc(pprof.Symbol), + "trace": http.HandlerFunc(pprof.Trace), + "allocs": pprof.Handler("allocs"), + "block": pprof.Handler("block"), + "heap": pprof.Handler("heap"), + "mutex": pprof.Handler("mutex"), + "threadcreate": pprof.Handler("threadcreate"), + } + + ErrRateReadWaitLong = errors.New("if the github.com/eudore/eudore/middleware speed limit waiting time is too long, it will time out") + ErrRateWriteWaitLong = errors.New("if the github.com/eudore/eudore/middleware speed limit waits for write time is too long, it will wait for timeout") +) diff --git a/middleware/cors.go b/middleware/cors.go index 3ce0e13..a526764 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -9,69 +9,69 @@ import ( // NewCorsFunc 函数创建一个Cors处理函数。 // -// origins是允许的origin,headers是跨域验证成功的添加的headers,例如:"Access-Control-Allow-Credentials"、"Access-Control-Allow-Headers"等。 +// pattens是允许的origin,headers是跨域验证成功的添加的headers,例如:"Access-Control-Allow-Credentials"、"Access-Control-Allow-Headers"等。 // -// 如果origins为空,设置为*。 +// 如果pattens为空,允许任意origin。 // 如果Access-Control-Allow-Methods header为空,设置为*。 // // Cors中间件注册不是全局中间件时,需要最后注册一次Options /*或404方法,否则Options请求匹配了默认404没有经过Cors中间件处理。 -func NewCorsFunc(origins []string, headers map[string]string) eudore.HandlerFunc { - if len(origins) == 0 { - origins = []string{"*"} - } +func NewCorsFunc(pattens []string, headers map[string]string) eudore.HandlerFunc { corsHeaders := make(map[string]string, len(headers)) for k, v := range headers { corsHeaders[textproto.CanonicalMIMEHeaderKey(k)] = v } - if corsHeaders["Access-Control-Allow-Methods"] == "" { - corsHeaders["Access-Control-Allow-Methods"] = "*" + if corsHeaders[eudore.HeaderAccessControlAllowMethods] == "" { + corsHeaders[eudore.HeaderAccessControlAllowMethods] = "*" } return func(ctx eudore.Context) { - origin := ctx.GetHeader("Origin") + origin := ctx.GetHeader(eudore.HeaderOrigin) + pos := strings.Index(origin, "://") + if pos != -1 { + origin = origin[pos+3:] + } // 检查是否未同源请求,cors和upgrade时存在origin header。 - if origin == "" || ctx.GetHeader(eudore.HeaderUpgrade) != "" { + if origin == "" || origin == ctx.Host() { return } - origin = strings.TrimPrefix(strings.TrimPrefix(origin, "http://"), "https://") - if !validateOrigin(origins, origin) { - ctx.WriteHeader(403) + if !validateOrigin(pattens, origin) { + ctx.WriteHeader(eudore.StatusForbidden) ctx.End() return } h := ctx.Response().Header() - h.Add("Access-Control-Allow-Origin", ctx.GetHeader("Origin")) + h.Add(eudore.HeaderAccessControlAllowOrigin, ctx.GetHeader(eudore.HeaderOrigin)) if ctx.Method() == eudore.MethodOptions { for k, v := range corsHeaders { h[k] = append(h[k], v) } - ctx.WriteHeader(204) + ctx.WriteHeader(eudore.StatusNoContent) ctx.End() } } } // validateOrigin 方法检查origin是否合法。 -func validateOrigin(origins []string, origin string) bool { - for _, i := range origins { - if matchStar(origin, i) { +func validateOrigin(pattens []string, origin string) bool { + for _, patten := range pattens { + if matchStar(patten, origin) { return true } } - return false + return pattens == nil } // matchStar 模式匹配对象,允许使用带'*'的模式。 -func matchStar(obj, patten string) bool { - ps := strings.Split(patten, "*") - if len(ps) < 2 { +func matchStar(patten, obj string) bool { + parts := strings.Split(patten, "*") + if len(parts) < 2 { return patten == obj } - if !strings.HasPrefix(obj, ps[0]) { + if !strings.HasPrefix(obj, parts[0]) { return false } - for _, i := range ps { + for _, i := range parts { if i == "" { continue } diff --git a/middleware/csrf.go b/middleware/csrf.go index 0f0cb8f..83dd31d 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -12,30 +12,25 @@ import ( "github.com/eudore/eudore" ) -// NewCsrfFunc 函数创建一个Csrf处理函数,key指定请求带有crsf参数的关键字,cookie是csrf设置cookie的基本详细。 -// -// key value: -// -// - "csrf" -// -// - "query: csrf" -// -// - "header: X-CSRF-Token" -// -// - "form: csrf" -// -// - func(ctx eudore.Context) string {return ctx.Query("csrf")} -// -// - nil -// -// cookie value: -// -// - "csrf" -// -// - http.Cookie{Name: "csrf"} -// -// - nil -func NewCsrfFunc(key, cookie interface{}) eudore.HandlerFunc { +/* +NewCsrfFunc 函数创建一个Csrf处理函数,key指定请求带有crsf参数的关键字,cookie是csrf设置cookie的基本详细。 + +key value: + + "csrf" + "query: csrf" + "header: X-CSRF-Token" + "form: csrf" + func(ctx eudore.Context) string {return ctx.Query("csrf")} + nil + +cookie value: + + "csrf" + http.Cookie{Name: "csrf"} + nil +*/ +func NewCsrfFunc(key, cookie any) eudore.HandlerFunc { keyfunc := getCsrfTokenFunc(key) basecookie := getCsrfBaseCookie(cookie) return func(ctx eudore.Context) { @@ -60,7 +55,7 @@ func NewCsrfFunc(key, cookie interface{}) eudore.HandlerFunc { } // getCsrfBaseCookie 函数创建应该CSRF基础Cookie。 -func getCsrfBaseCookie(cookie interface{}) http.Cookie { +func getCsrfBaseCookie(cookie any) http.Cookie { switch val := cookie.(type) { case http.Cookie: return val @@ -76,7 +71,7 @@ func getCsrfBaseCookie(cookie interface{}) http.Cookie { // getCsrfTokenFunc 函数根据key一个csrf token获取函数。 // // 如果key是字符串类型通过query、header、form前缀返回对应获得token方法;如果key为func(eudore.Context) string类型直接返回;否在返回默认函数。 -func getCsrfTokenFunc(key interface{}) func(eudore.Context) string { +func getCsrfTokenFunc(key any) func(eudore.Context) string { switch val := key.(type) { case string: switch { @@ -93,10 +88,13 @@ func getCsrfTokenFunc(key interface{}) func(eudore.Context) string { case strings.HasPrefix(val, "form:"): val = strings.TrimSpace(val[5:]) return func(ctx eudore.Context) string { - if strings.Index(ctx.GetHeader(eudore.HeaderContentType), eudore.MimeMultipartForm) == -1 { - return "" + contenttype := ctx.GetHeader(eudore.HeaderContentType) + if ctx.Request().Body == http.NoBody || + strings.Contains(contenttype, eudore.MimeApplicationForm) || + strings.Contains(contenttype, eudore.MimeMultipartForm) { + return ctx.FormValue(val) } - return ctx.FormValue(val) + return "" } } case func(eudore.Context) string: diff --git a/middleware/doc.go b/middleware/doc.go index 414750f..725b108 100644 --- a/middleware/doc.go +++ b/middleware/doc.go @@ -1,32 +1,41 @@ /* Package middleware 实现eudore基础请求中间件和处理函数。 -BasicAuth +# BasicAuth 实现请求BasicAuth访问认证 参数: + map[string]string 允许的用户名和密码的键值对map。 + example: + app.AddMiddleware(middleware.NewBasicAuthFunc(map[string]string{"user": "pw"})) -BodyLimit +# BodyLimit 限制请求body大小 参数: + int64 指定限制body的长度 + examole: + app.AddMiddleware(middleware.NewBodyLimitFunc(32 << 20)) -Black +# Black 实现黑白名单管理及管理后台 参数: + map[string]bool 指明初始化使用的黑白名单,true为白白名单/false为黑名单 eudore.Router 为注入黑名单管理路由的路由器。 + example: + app.AddMiddleware(middleware.NewBlackFunc(map[string]bool{ "192.168.100.0/24": true, "192.168.75.0/30": true, @@ -36,13 +45,16 @@ example: "0.0.0.0/0": false, }, app.Group("/eudore/debug"))) -Breaker +# Breaker 实现路由规则熔断 参数: + eudore.Router + 属性: + MaxConsecutiveSuccesses uint32 最大连续成功次数 MaxConsecutiveFailures uint32 最大连续失败次数 OpenWait time.Duration 打开状态恢复到半开状态下等待时间 @@ -58,50 +70,63 @@ example: 在关闭状态下连续错误一定次数后熔断器进入半开状态;在半开状态下请求将进入限流状态,半开连续错误一定次数后进入打开状态,半开连续成功一定次数后回到关闭状态;在进入关闭状态后等待一定时间后恢复到半开状态。 -Cache +# Cache 创建一个缓存中间件,对Get请求具有缓存和SingleFlight双重效果。 参数: + context.Context 控制默认cacheMap清理过期数据的生命周期 time.Duration 请求数据缓存时间,默认秒 cacheStore 缓存存储对象 + example: + app.AddMiddleware(middleware.NewCacheFunc(time.Second*10, app.Context)) -Compress +# Compress 创建响应压缩中间件,默认提供gzip和deflate压缩 参数: + string 压缩名称 - func() interface{} 压缩器创建函数 + func() any 压缩器创建函数 int 压缩级别 + example: + import: "github.com/andybalholm/brotli" - app.AddMiddleware(middleware.NewCompressFunc("br", func() interface{} { return brotli.NewWriter(ioutil.Discard) })) - app.AddMiddleware(middleware.NewCompressGzipFunc(5)) - app.AddMiddleware(middleware.NewCompressDeflateFunc(5)) + app.AddMiddleware(middleware.NewCompressMixinsFunc(nil)) + app.AddMiddleware(middleware.NewCompressFunc("br", func() any { return brotli.NewWriter(io.Discard) })) + app.AddMiddleware(middleware.NewCompressGzipFunc()) + app.AddMiddleware(middleware.NewCompressDeflateFunc()) -ContextWarp +# ContextWarp 使中间件之后的处理函数使用的eudore.Context对象为新的Context 参数: + func(eudore.Context) eudore.Context 指定ContextWarp使用的eudore.Context封装函数 + example: + app.AddMiddleware(middleware.NewContextWarpFunc(newContextParams)) func newContextParams(ctx eudore.Context) eudore.Context { return contextParams{ctx} } -Cors +# Cors 跨域请求 参数: + []string 允许使用的origin,默认值为:[]string{"*"} map[string]string CORS验证通过后给请求添加的协议headers,用来设置CORS控制信息 + example: + app.AddMiddleware("global", middleware.NewCorsFunc([]string{"www.*.com", "example.com", "127.0.0.1:*"}, map[string]string{ "Access-Control-Allow-Credentials": "true", "Access-Control-Allow-Headers": "Authorization,DNT,X-CustomHeader,Keep-Alive,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,X-Parent-Id", @@ -112,96 +137,116 @@ example: Cors中间件注册不是全局中间件时,需要最后注册一次Options /*或404方法,否则Options请求匹配了默认404没有经过Cors中间件处理。 -Csrf +# Csrf 校验设置CSRF token 参数: - interface{} 指明获取csrf token的方法,下列是允许使用的值 + + any 指明获取csrf token的方法,下列是允许使用的值 - "csrf" - "query: csrf" - "header: X-CSRF-Token" - "form: csrf" - func(ctx eudore.Context) string {return ctx.Query("csrf")} - nil - interface{} 指明设置Cookie的基础信息,下列是允许使用的值 + any 指明设置Cookie的基础信息,下列是允许使用的值 - "csrf" - http.Cookie{Name: "csrf"} - nil + example: + app.AddMiddleware(middleware.NewCsrfFunc("csrf", nil)) -Dump +# Dump 截取请求信息的中间件,将匹配请求使用webscoket输出给客户端。 参数: + router参数是eudore.Router类型,然后注入拦截路由处理。 + example: + app.AddMiddleware(middleware.NewDumpFunc(app.Group("/eudore/debug"))) -Header +# Header 添加响应Header 参数: + http.Header 需要添加的Header内存 + examaple: + app.AddMiddleware(middleware.NewHeaderFunc(http.Header{ "Cache-Control": []string{"no-cache"}, })) app.AddMiddleware(middleware.NewHeaderWithSecureFunc(nil)) -HeaderFilte +# HeaderFilte 对来源于外部ip请求,过滤指定请求header 参数: + []string 指定内部ip,默认[]string{"10.0.0.0/8", "172.16.0.0/12", "192.0.0.0/24", "127.0.0.1"} []string 指定需要过滤的请求header,默认[]string{HeaderXRealIP, HeaderXForwardedFor, HeaderXForwardedHost, HeaderXForwardedProto, HeaderXRequestID, HeaderXTraceID} + examaple: + app.AddMiddleware(middleware.NewHeaderFilteFunc(nil, nil)) app.AddMiddleware(middleware.NewHeaderFilteFunc([]string{"127.0.0.1"}, nil)) -Logger +# Logger 输出请求access logger并记录相关fields 参数: + eudore.App 指定App对象,需要使用App.Logger输出日志。 ...string 指定额外添加的Params值,如果值非空则会加入到access logger fields中 + example: + app.AddMiddleware(middleware.NewLoggerFunc(app, "route")) -Rate +# Rate 实现请求令牌桶限流 参数: - int 每周期(默认秒)增加speed个令牌 - int 最多拥有的令牌数量 - ...interface{} 额外使用的Options,根据类型来断言设置选项 + + int 每周期(默认秒)增加speed个令牌 + int 最多拥有的令牌数量 + ...any 额外使用的Options,根据类型来断言设置选项 context.Context => 控制cleanupVisitors退出的生命周期 time.Duration => 基础时间周期单位,默认秒 func(eudore.Context) string => 限流获取key的函数,默认Context.ReadIP + example: + // 限流 每秒一个请求,最多保存3个请求 app.AddMiddleware(middleware.NewRateRequestFunc(1, 3, app.Context)) // 限速 每秒32Kb流量,最多保存128Kb流量 app.AddMiddleware(middleware.NewRateSpeedFunc(32*1024, 128*1024, app.Context)) -Recover +# Recover 恢复panic抛出的错误,并输出日志、返回异常响应 example: + app.AddMiddleware(middleware.NewRecoverFunc()) -Referer +# Referer 检查请求Referer Header值是否有效 参数: + map[string]bool 设置referer值是否有效 "" => 其他值未匹配时使用的默认值。 "origin" => 请求Referer和Host同源情况下,检查host为referer前缀,origin检查在其他值检查之前。 @@ -209,7 +254,9 @@ Referer "www.eudore.cn/*" => www.eudore.cn域名全部请求,不指明http或https时为同时包含http和https "www.eudore.cn/api/*" => www.eudore.cn域名全部/api/前缀的请求 "https://www.eudore.cn/*" => www.eudore.cn仅匹配https。 + example: + app.AddMiddleware(middleware.NewRefererFunc(map[string]bool{ "": true, "origin": false, @@ -218,22 +265,28 @@ example: "www.example.com/*": true, })) -RequestID +# RequestID 给请求、响应、日志设置一个请求ID 参数: + func() string 用于创建一个请求ID,默认使用时间戳随机数 + example: + app.AddMiddleware(middleware.NewRequestIDFunc(nil)) -Rewrite +# Rewrite 重写请求路径,需要注册全局中间件 参数: + map[string]string 请求匹配模式对应的目标模式 + example: + app.AddMiddleware("global", middleware.NewRewriteFunc(map[string]string{ "/js/*": "/public/js/$0", "/d/*": "/d/$0-$0", @@ -244,14 +297,17 @@ example: "/help/*": "$0", })) -Router +# Router 用于执行额外的路由匹配行为 参数: - map[string]interface{} 请求路径对应的执行函数,路径前缀不指定方法则为Any方法 + + map[string]any 请求路径对应的执行函数,路径前缀不指定方法则为Any方法 + example: - app.AddMiddleware(middleware.NewRouterFunc(map[string]interface{}{ + + app.AddMiddleware(middleware.NewRouterFunc(map[string]any{ "/api/:v/*": func(ctx eudore.Context) { ctx.Request().URL.Path = "/api/v3/" + ctx.GetParam("*") }, @@ -261,11 +317,12 @@ example: }, })) -RouterRewrite +# RouterRewrite 基于Router中间件实现路由重写,参考Rewrite example: + app.AddMiddleware("global", middleware.NewRouterRewriteFunc(map[string]string{ "/js/*": "/public/js/$0", "/d/*": "/d/$0-$0", @@ -276,11 +333,10 @@ example: "/help/*": "$0", })) -Timeout +# Timeout 设置请求处理超时时间,如果超时返回503状态码并取消context, 实现难点:写入中超时状态码异常、panic栈无法捕捉信息异常、http.Header并发读写、sync.Pool回收了Context、Context数据竟态检测 - */ package middleware // import "github.com/eudore/eudore/middleware" diff --git a/middleware/look.go b/middleware/look.go index f3ebe10..7380093 100644 --- a/middleware/look.go +++ b/middleware/look.go @@ -4,33 +4,28 @@ import ( "encoding/json" "fmt" "reflect" + "strconv" "strings" "text/template" "github.com/eudore/eudore" ) -// 定义LookValue使用的Mime值 -const ( - MimeValueJSON = "value/json" - MimeValueHTML = "value/html" - MimeValueText = "value/text" -) - // NewLookFunc 函数创建一个访问对象数据处理函数。 // -// 如果参数类型为func(eudore.Context) interface{},可以动态返回需要渲染的数据。 +// 如果参数类型为func(eudore.Context) any,可以动态返回需要渲染的数据。 // // 获取请求路由参数"*"为object访问路径,返回object指定属性的数据,允许使用下列参数: -// d=10 depth递归显时最大层数 -// all=false 是否显时非导出属性 -// format=html/json/text 设置数据显示格式 -// godoc=https://golang.org 设置html格式链接的godoc服务地址 -// width=60 设置html格式缩进宽度 -func NewLookFunc(data interface{}) eudore.HandlerFunc { - fn, ok := data.(func(eudore.Context) interface{}) +// +// d=10 depth递归显时最大层数; +// all=false 是否显时非导出属性; +// format=html/json/text 设置数据显示格式; +// godoc=https://golang.org 设置html格式链接的godoc服务地址; +// width=60 设置html格式缩进宽度。 +func NewLookFunc(data any) eudore.HandlerFunc { + fn, ok := data.(func(eudore.Context) any) if !ok { - fn = func(eudore.Context) interface{} { + fn = func(eudore.Context) any { return data } } @@ -41,7 +36,7 @@ func NewLookFunc(data interface{}) eudore.HandlerFunc { } ctx.SetHeader(eudore.HeaderXEudoreAdmin, "look") look := NewLookValue(ctx) - val, err := eudore.GetWithValue(data, strings.Replace(ctx.GetParam("*"), "/", ".", -1), nil, look.ShowAll) + val, err := eudore.GetAnyByPathWithValue(data, strings.ReplaceAll(ctx.GetParam("*"), "/", "."), nil, look.ShowAll) if err != nil { ctx.Fatal(err) return @@ -49,21 +44,17 @@ func NewLookFunc(data interface{}) eudore.HandlerFunc { look.Scan(val) switch getRequestForma(ctx) { - case "json": + case QueryFormatJSON: ctx.SetHeader(eudore.HeaderContentType, eudore.MimeApplicationJSONCharsetUtf8) - encoder := json.NewEncoder(ctx) - if !strings.Contains(ctx.GetHeader(eudore.HeaderAccept), eudore.MimeApplicationJSON) { - encoder.SetIndent("", "\t") - } - encoder.Encode(look) - case "html": - tmpl := getLookTemplate(strings.TrimSuffix(ctx.Path(), "/"), ctx.Querys().Encode(), "view") + _ = eudore.RenderJSON(ctx, look) + case QueryFormatHTML: + tmpl := getLookTemplate(strings.TrimSuffix(ctx.Path(), "/"), ctx.Querys().Encode()) ctx.SetHeader(eudore.HeaderContentType, eudore.MimeTextHTMLCharsetUtf8) - tmpl.ExecuteTemplate(ctx, "view", viewData{ctx.GetParam("*"), eudore.GetStringInt(ctx.GetQuery("width"), 60), look}) + _ = tmpl.ExecuteTemplate(ctx, "view", viewData{ctx.GetParam("*"), eudore.GetAnyByString(ctx.GetQuery("width"), 60), look}) default: - tmpl := getLookTemplate(strings.TrimSuffix(ctx.Path(), "/"), ctx.Querys().Encode(), "text") + tmpl := getLookTemplate(strings.TrimSuffix(ctx.Path(), "/"), ctx.Querys().Encode()) ctx.SetHeader(eudore.HeaderContentType, eudore.MimeTextPlainCharsetUtf8) - tmpl.ExecuteTemplate(ctx, "text", look) + _ = tmpl.ExecuteTemplate(ctx, "text", look) } } } @@ -91,7 +82,7 @@ func NewBindLook(renders map[string]eudore.HandlerDataFunc) eudore.HandlerDataFu } // RenderValueJSON 实现渲染Value为JSON格式。 -func RenderValueJSON(ctx eudore.Context, data interface{}) error { +func RenderValueJSON(ctx eudore.Context, data any) error { look := NewLookValue(ctx) look.Scan(reflect.ValueOf(data)) @@ -107,7 +98,7 @@ func RenderValueJSON(ctx eudore.Context, data interface{}) error { } // RenderValueText 实现渲染Value为Text格式。 -func RenderValueText(ctx eudore.Context, data interface{}) error { +func RenderValueText(ctx eudore.Context, data any) error { look := NewLookValue(ctx) look.Scan(reflect.ValueOf(data)) @@ -116,12 +107,12 @@ func RenderValueText(ctx eudore.Context, data interface{}) error { ctx.SetHeader(eudore.HeaderContentType, eudore.MimeTextPlainCharsetUtf8) } - tmpl := getLookTemplate(strings.TrimSuffix(ctx.Path(), "/"), ctx.Querys().Encode(), "text") + tmpl := getLookTemplate(strings.TrimSuffix(ctx.Path(), "/"), ctx.Querys().Encode()) return tmpl.ExecuteTemplate(ctx, "text", look) } // RenderValueHTML 实现渲染Value为HTML格式。 -func RenderValueHTML(ctx eudore.Context, data interface{}) error { +func RenderValueHTML(ctx eudore.Context, data any) error { look := NewLookValue(ctx) look.Scan(reflect.ValueOf(data)) @@ -129,8 +120,8 @@ func RenderValueHTML(ctx eudore.Context, data interface{}) error { if val := header.Get(eudore.HeaderContentType); len(val) == 0 { ctx.SetHeader(eudore.HeaderContentType, eudore.MimeTextHTMLCharsetUtf8) } - tmpl := getLookTemplate(strings.TrimSuffix(ctx.Path(), "/"), ctx.Querys().Encode(), "view") - return tmpl.ExecuteTemplate(ctx, "view", viewData{ctx.GetParam("*"), eudore.GetStringInt(ctx.GetQuery("width"), 60), look}) + tmpl := getLookTemplate(strings.TrimSuffix(ctx.Path(), "/"), ctx.Querys().Encode()) + return tmpl.ExecuteTemplate(ctx, "view", viewData{ctx.GetParam("*"), eudore.GetAnyByString(ctx.GetQuery("width"), 60), look}) } func getRequestForma(ctx eudore.Context) string { @@ -141,11 +132,11 @@ func getRequestForma(ctx eudore.Context) string { for _, accept := range strings.Split(ctx.GetHeader(eudore.HeaderAccept), ",") { switch strings.TrimSpace(accept) { case eudore.MimeApplicationJSON: - return "json" + return QueryFormatJSON case eudore.MimeTextHTML: - return "html" + return QueryFormatHTML case eudore.MimeTextPlain, eudore.MimeText: - return "text" + return QueryFormatText } } return "" @@ -157,27 +148,27 @@ type viewData struct { Data *LookValue } -func getLookTemplate(path, querys, format string) *template.Template { +func getLookTemplate(path, querys string) *template.Template { depth := 0 paths := []string{path} if querys != "" { querys = "?" + querys } - temp := template.New("look").Funcs(template.FuncMap{ + tpl := template.New("look").Funcs(template.FuncMap{ "addtab": func() string { depth++; return "" }, "subtab": func() string { depth--; return "" }, "gettab": func() string { return strings.Repeat("\t", depth) }, "addpath": func(path string) string { paths = append(paths, path); return "" }, "subpath": func() string { paths = paths[:len(paths)-1]; return "" }, - "getpath": func() string { return fmt.Sprintf("%s%s", strings.Join(paths, "/"), querys) }, - "isnil": func(i interface{}) bool { return reflect.ValueOf(i).IsNil() }, + "getpath": func() string { return strings.Join(paths, "/") + querys }, + "isnil": func(i any) bool { return reflect.ValueOf(i).IsNil() }, "isline": func(i int) bool { return i%16 == 0 }, "showint": func(i string) string { return strings.Repeat(" ", 4-len(i)) + i }, }) for _, i := range lookTemplate.Templates() { - temp.AddParseTree(i.Name(), i.Tree) + _, _ = tpl.AddParseTree(i.Name(), i.Tree) } - return temp + return tpl } var lookTemplate, _ = template.New("look").Funcs(template.FuncMap{ @@ -203,7 +194,7 @@ var lookTemplate, _ = template.New("look").Funcs(template.FuncMap{
{{- template "html" .Data -}}
+ Types of profiles available: @@ -146,46 +135,42 @@ Count Profile Descriptions // HandlerPprofGoroutine 函数处理pprof Goroutine数据,响应format=text/json/html三种格式。 func HandlerPprofGoroutine(ctx eudore.Context) { p := pprof.Lookup("goroutine") - debug := eudore.GetStringInt(ctx.GetQuery("debug")) + debug := eudore.GetAnyByString[int](ctx.GetQuery("debug")) if debug == 0 { ctx.SetHeader(eudore.HeaderContentType, "application/octet-stream") ctx.SetHeader(eudore.HeaderContentDisposition, "attachment; filename=\"goroutine\"") - p.WriteTo(ctx, 0) + _ = p.WriteTo(ctx, 0) return } format := ctx.GetQuery("format") - if format == "text" { + if format == QueryFormatText { ctx.SetHeader(eudore.HeaderContentType, eudore.MimeTextPlainCharsetUtf8) - p.WriteTo(ctx, debug) + _ = p.WriteTo(ctx, debug) return } var buf bytes.Buffer - p.WriteTo(&buf, debug) - var data interface{} + _ = p.WriteTo(&buf, debug) + var data any if debug == 1 { data = newGoroutineDebug1(buf.String()) } else { data = newGoroutineDebug2(buf.String()) } - if format == "json" { + if format == QueryFormatJSON { ctx.SetHeader(eudore.HeaderContentType, eudore.MimeApplicationJSONCharsetUtf8) - encoder := json.NewEncoder(ctx) - if !strings.Contains(ctx.GetHeader(eudore.HeaderAccept), eudore.MimeApplicationJSON) { - encoder.SetIndent("", "\t") - } - encoder.Encode(data) + _ = eudore.RenderJSON(ctx, data) } else { - godoc := eudore.GetString(ctx.GetQuery("godoc"), ctx.GetParam("godoc"), eudore.DefaultGodocServer) + godoc := eudore.GetAnyByString(ctx.GetQuery("godoc"), ctx.GetParam("godoc"), eudore.DefaultGodocServer) godoc = strings.TrimSuffix(godoc, "/") - tmpl, _ := template.New("goroutine").Funcs(template.FuncMap{ + tpl, _ := template.New("goroutine").Funcs(template.FuncMap{ "getPackage": getGodocPackage(godoc), "getSource": getGodocSource(godoc), }).Parse(pprofGoroutineTemplate) ctx.SetHeader(eudore.HeaderContentType, eudore.MimeTextHTMLCharsetUtf8) - tmpl.Execute(ctx, &goroutineData{ + _ = tpl.Execute(ctx, &goroutineData{ Data: data, Debug: debug, }) @@ -193,7 +178,7 @@ func HandlerPprofGoroutine(ctx eudore.Context) { } type goroutineData struct { - Data interface{} + Data any Debug int Godoc string } @@ -228,10 +213,9 @@ type goroutineDebug2Line struct { } func newGoroutineDebug1(str string) []goroutineDebug1Block { - var reg *regexp.Regexp = regexp.MustCompile(`#\t0x(\S+)\t(\S+)\+0x(\S+)(\s+)(\S+):(\d+)`) - var blocks []goroutineDebug1Block - pos := strings.IndexByte(str, '\n') - routines := strings.Split(str[pos+1:], "\n\n") + reg := regexp.MustCompile(`#\t0x(\S+)\t(\S+)\+0x(\S+)(\s+)(\S+):(\d+)`) + routines := strings.Split(str[strings.IndexByte(str, '\n')+1:], "\n\n") + blocks := make([]goroutineDebug1Block, 0, len(routines)) for i := range routines { if routines[i] == "" { continue @@ -255,11 +239,11 @@ func newGoroutineDebug1(str string) []goroutineDebug1Block { func newGoroutineDebug2(str string) []goroutineDebug2Block { reghead := regexp.MustCompile(`goroutine (\d+) \[(.*)\]`) regline := regexp.MustCompile(`\n(\S+)\((.*)\)\n\t(\S+):(\d+)( \+0x\S+)?|\n(created by )(\S+)\n\t(\S+):(\d+) \+0x(\S+)`) - var blocks []goroutineDebug2Block routines := strings.Split(str, "\n\n") + blocks := make([]goroutineDebug2Block, 0, len(routines)) for i := range routines { head := reghead.FindStringSubmatch(routines[i]) - var block = goroutineDebug2Block{Number: head[1], State: head[2]} + block := goroutineDebug2Block{Number: head[1], State: head[2]} matchs := regline.FindAllStringSubmatch(routines[i], -1) for _, m := range matchs { if m[6] != "created by " { diff --git a/middleware/rate.go b/middleware/rate.go index c81cc7b..ebf1b99 100644 --- a/middleware/rate.go +++ b/middleware/rate.go @@ -2,7 +2,6 @@ package middleware import ( "context" - "errors" "io" "sync" "time" @@ -20,8 +19,8 @@ import ( // // time.Duration => 基础时间周期单位,默认秒 // -// func(eudore.Context) string => 限流获取key的函数,默认Context.ReadIP -func NewRateRequestFunc(speed, max int64, options ...interface{}) eudore.HandlerFunc { +// func(eudore.Context) string => 限流获取key的函数,默认Context.ReadIP。 +func NewRateRequestFunc(speed, max int64, options ...any) eudore.HandlerFunc { return newRate(speed, max, options...).HandlerRequest } @@ -32,11 +31,11 @@ func NewRateRequestFunc(speed, max int64, options ...interface{}) eudore.Handler // speed速度不要小于通常Reader的缓冲区大小(最好大于4kB 4096),否则无法请求到住够的令牌导致阻塞。 // // Read时先请求缓冲区大小数量的令牌,然后返还未使用的令牌数量;Write时请求写入数据长度数量的令牌。 -func NewRateSpeedFunc(speed, max int64, options ...interface{}) eudore.HandlerFunc { +func NewRateSpeedFunc(speed, max int64, options ...any) eudore.HandlerFunc { return newRate(speed, max, options...).HandlerSpeed } -func newRate(speed, max int64, options ...interface{}) *rate { +func newRate(speed, max int64, options ...any) *rate { r := &rate{ visitors: make(map[string]*rateBucket), GetKeyFunc: func(ctx eudore.Context) string { @@ -74,12 +73,12 @@ func (r *rate) HandlerRequest(ctx eudore.Context) { func (r *rate) HandlerSpeed(ctx eudore.Context) { rate := r.GetVisitor(r.GetKeyFunc(ctx)) httpctx := ctx.GetContext() - ctx.Request().Body = &rateRequqest{ + ctx.Request().Body = &requqestReaderRate{ ReadCloser: ctx.Request().Body, Context: httpctx, rateBucket: rate, } - ctx.SetResponse(&rateResponse{ + ctx.SetResponse(&responseWriterRate{ ResponseWriter: ctx.Response(), Context: httpctx, rateBucket: rate, @@ -125,22 +124,19 @@ func (r *rate) cleanupVisitors(ctx context.Context) { } } -var errRateReadWaitLong = errors.New("If the github.com/eudore/eudore/middleware speed limit waiting time is too long, it will time out") -var errRateWriteWaitLong = errors.New("If the github.com/eudore/eudore/middleware speed limit waits for write time is too long, it will wait for timeout") - -type rateRequqest struct { +type requqestReaderRate struct { io.ReadCloser context.Context *rateBucket } -type rateResponse struct { +type responseWriterRate struct { eudore.ResponseWriter context.Context *rateBucket } -func (r *rateRequqest) Read(body []byte) (int, error) { +func (r *requqestReaderRate) Read(body []byte) (int, error) { length := len(body) if r.Wait(r.Context, int64(length)) { n, err := r.ReadCloser.Read(body) @@ -151,23 +147,34 @@ func (r *rateRequqest) Read(body []byte) (int, error) { } err := r.Err() if err == nil { - err = errRateReadWaitLong + err = ErrRateReadWaitLong + } + return 0, err +} + +func (r *responseWriterRate) Write(data []byte) (int, error) { + if r.Wait(r.Context, int64(len(data))) { + return r.ResponseWriter.Write(data) + } + err := r.Err() + if err == nil { + err = ErrRateWriteWaitLong } return 0, err } -func (r *rateResponse) Write(body []byte) (int, error) { - if r.Wait(r.Context, int64(len(body))) { - return r.ResponseWriter.Write(body) +func (r *responseWriterRate) WriteString(data string) (int, error) { + if r.Wait(r.Context, int64(len(data))) { + return r.ResponseWriter.WriteString(data) } err := r.Err() if err == nil { - err = errRateWriteWaitLong + err = ErrRateWriteWaitLong } return 0, err } -// rate 定义限流器 +// rate 定义限流器。 type rate struct { mu sync.RWMutex visitors map[string]*rateBucket @@ -193,7 +200,7 @@ func newBucket(speed, max int64) *rateBucket { func (r *rateBucket) Put(n int64) { r.Lock() - r.last = r.last - n*r.speed + r.last -= n * r.speed r.Unlock() } @@ -204,7 +211,7 @@ func (r *rateBucket) Allow(n int64) bool { n = r.last + n*r.speed if n < now { r.last = n - now = now - r.max + now -= r.max if r.last < now { r.last = now } @@ -219,7 +226,7 @@ func (r *rateBucket) Wait(ctx context.Context, n int64) bool { n = r.last + n*r.speed if n < now { r.last = n - now = now - r.max + now -= r.max if r.last < now { r.last = now } diff --git a/middleware/referer.go b/middleware/referer.go index a413ee7..9973254 100644 --- a/middleware/referer.go +++ b/middleware/referer.go @@ -1,25 +1,24 @@ package middleware import ( - "fmt" "strings" "github.com/eudore/eudore" ) -// NewRefererFunc 函数创建Referer header检查中间件。 +// NewRefererFunc 函数创建Referer header检查中间件,如果不指定协议匹配http和https,默认拒绝。 // -// "" => 其他值未匹配时使用的默认值。 +// 阅览器发送Referer值受html meta name referrer和Response Header Referrer-Policy影响。 // // "origin" => 请求Referer和Host同源情况下,检查host为referer前缀,origin检查在其他值检查之前。 // -// "*" => 任意域名端口 +// "*" => 任意域名端口,包含无Referer值。 // -// "www.eudore.cn/*" => www.eudore.cn域名全部请求,不指明http或https时为同时包含http和https +// "www.eudore.cn/*" => www.eudore.cn域名全部请求,不指明http或https时为同时包含http和https。 // // "www.eudore.cn:*/*" => www.eudore.cn任意端口的全部请求,不包含没有指明端口的情况。 // -// "www.eudore.cn/api/*" => www.eudore.cn域名全部/api/前缀的请求 +// "www.eudore.cn/api/*" => www.eudore.cn域名全部/api/前缀的请求。 // // "https://www.eudore.cn/*" => www.eudore.cn仅匹配https。 func NewRefererFunc(data map[string]bool) eudore.HandlerFunc { @@ -28,48 +27,53 @@ func NewRefererFunc(data map[string]bool) eudore.HandlerFunc { tree := new(refererNode) for k, v := range data { - if strings.HasPrefix(k, "http://") || strings.HasPrefix(k, "https://") || k == "" { - tree.insert(k).data = fmt.Sprint(v) + if strings.HasPrefix(k, "http://") || strings.HasPrefix(k, "https://") || k == "" || k == "*" { + tree.insert(k, v) } else { - tree.insert("http://" + k).data = fmt.Sprint(v) - tree.insert("https://" + k).data = fmt.Sprint(v) + tree.insert("http://"+k, v) + tree.insert("https://"+k, v) } } return func(ctx eudore.Context) { referer := ctx.GetHeader(eudore.HeaderReferer) if origin && checkRefererOrigin(ctx, referer) { - if !originvalue { - ctx.WriteHeader(eudore.StatusForbidden) - ctx.WriteString("invalid Referer header " + referer) - ctx.End() + if originvalue { + return + } + } else { + node := tree.matchNode(referer) + if node != nil && node.data { + return } - return - } - node := tree.matchNode(referer) - if node != nil && node.data == "false" { - ctx.WriteHeader(eudore.StatusForbidden) - ctx.WriteString("invalid Referer header " + referer) - ctx.End() } + ctx.WriteHeader(eudore.StatusForbidden) + ctx.WriteString("invalid Referer header " + referer) + ctx.End() } } func checkRefererOrigin(ctx eudore.Context, referer string) bool { - if referer == "" || len(referer) < 8 { + if len(referer) < 8 { return false } - return strings.HasPrefix(referer[7:], ctx.Host()) || strings.HasPrefix(referer[8:], ctx.Host()) + + pos := strings.Index(referer, "://") + if pos != -1 { + referer = referer[pos+3:] + } + return strings.HasPrefix(referer, ctx.Host()) } type refererNode struct { path string + has bool + data bool wildcard *refererNode children []*refererNode - data string } -func (node *refererNode) insert(path string) *refererNode { +func (node *refererNode) insert(path string, data bool) { paths := strings.Split(path, "*") newpaths := make([]string, 1, len(paths)*2-1) newpaths[0] = paths[0] @@ -82,7 +86,8 @@ func (node *refererNode) insert(path string) *refererNode { for _, p := range newpaths { node = node.insertNode(p) } - return node + node.has = true + node.data = data } func (node *refererNode) insertNode(path string) *refererNode { @@ -122,7 +127,7 @@ func (node *refererNode) insertNode(path string) *refererNode { } func (node *refererNode) matchNode(path string) *refererNode { - if path == "" && node.data != "" { + if path == "" && node.has { return node } for _, current := range node.children { @@ -133,16 +138,7 @@ func (node *refererNode) matchNode(path string) *refererNode { } } if node.wildcard != nil { - if node.wildcard.children != nil { - pos := strings.IndexByte(path, '/') - if pos == -1 { - pos = len(path) - } - if result := node.wildcard.matchNode(path[pos:]); result != nil { - return result - } - } - if node.wildcard.data != "" { + if node.wildcard.has { return node.wildcard } } diff --git a/middleware/rewrite.go b/middleware/rewrite.go index d55a099..465aeee 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -172,7 +172,7 @@ func (node *rewriteNode) matchNode(path string, result *matchResult) bool { return false } -// 获取两个字符串的最大公共前缀,返回最大公共前缀和是否拥有最大公共前缀 +// 获取两个字符串的最大公共前缀,返回最大公共前缀和是否拥有最大公共前缀。 func getSubsetPrefix(str1, str2 string) (string, bool) { findSubset := false for i := 0; i < len(str1) && i < len(str2); i++ { @@ -186,7 +186,7 @@ func getSubsetPrefix(str1, str2 string) (string, bool) { if len(str1) > len(str2) { return str2, findSubset } else if len(str1) == len(str2) { - //fix "" not a subset of "" + // fix "" not a subset of "" return str1, str1 == str2 } diff --git a/middleware/router.go b/middleware/router.go index 84c1fc1..95839ab 100644 --- a/middleware/router.go +++ b/middleware/router.go @@ -11,11 +11,11 @@ import ( // NewRouterFunc 函数创建一个路由器中间件,将根据路由路径匹配执行对应的多个处理函数。 // // 如果key为"router",val类型为eudore.Router,则使用改路由器处理请求。 -func NewRouterFunc(data map[string]interface{}) eudore.HandlerFunc { +func NewRouterFunc(data map[string]any) eudore.HandlerFunc { router, ok := data["router"].(eudore.Router) delete(data, "router") if !ok { - router = eudore.NewRouterStd(nil) + router = eudore.NewRouter(nil) router.AddHandler("404", "", eudore.HandlerEmpty) router.AddHandler("405", "", eudore.HandlerEmpty) } @@ -45,7 +45,7 @@ func NewRouterFunc(data map[string]interface{}) eudore.HandlerFunc { // // RouterRewrite中间件使用参数和Rewrite中间件完全相同。 func NewRouterRewriteFunc(data map[string]string) eudore.HandlerFunc { - mapping := make(map[string]interface{}, len(data)) + mapping := make(map[string]any, len(data)) for k, v := range data { k = getRouterRewritePath(k) mapping[k] = newRouterRewriteFunc(v) @@ -67,7 +67,7 @@ func getRouterRewritePath(path string) string { } num++ } else { - str = str + string(path[i]) + str += string(path[i]) } } return str diff --git a/policy/pbac.go b/policy/pbac.go index 3ff16cd..a0f0b23 100644 --- a/policy/pbac.go +++ b/policy/pbac.go @@ -34,8 +34,8 @@ type Policys struct { // Signaturer 定义Policys进行用户信息签名的对象。 type Signaturer interface { - Signed(interface{}) string - Parse(string, interface{}) error + Signed(any) string + Parse(string, any) error } // Member 定义Policy授权对象。 @@ -48,7 +48,7 @@ type Member struct { Policy *Policy `json:"-" alias:"-"` } -// NewPolicys 函数创建默认策略访问控制器 +// NewPolicys 函数创建默认策略访问控制器。 func NewPolicys() *Policys { policys := &Policys{ Signaturer: NewSignaturerJwt([]byte("eudore")), @@ -64,9 +64,12 @@ func NewPolicys() *Policys { return policys } -// HandleHTTP 方法实现eudore.handlerHTTP(handler.go#L49)接口,作为请求处理中间件的处理函数,实现访问控制鉴权。 +// HandleHTTP 方法实现eudore.handlerHTTP(handler.go#L49)接口, +// 作为请求处理中间件的处理函数,实现访问控制鉴权。 // // 请求的param action为空回跳过鉴权方法。 +// +//nolint:cyclop,funlen,gocyclo func (ctl *Policys) HandleHTTP(ctx eudore.Context) { action := ctl.ActionFunc(ctx) if action == "" { @@ -83,8 +86,8 @@ func (ctl *Policys) HandleHTTP(ctx eudore.Context) { } ctx.SetParam(eudore.ParamUserid, fmt.Sprint(userid)) - var now = time.Now() - var datas map[string][]interface{} + now := time.Now() + var datas map[string][]any var names []string // 遍历用户的全部授权的全部stmt @@ -107,7 +110,7 @@ matchPolicys: } names = append(names, p.PolicyName) if datas == nil { - datas = make(map[string][]interface{}) + datas = make(map[string][]any) } for key, val := range s.data { datas[key] = append(datas[key], val...) @@ -141,9 +144,9 @@ func (ctl *Policys) GetMember(userid int) []*Member { } // HandleRuntime 方法返回Policys运行时数据。 -func (ctl *Policys) HandleRuntime(ctx eudore.Context) interface{} { +func (ctl *Policys) HandleRuntime(eudore.Context) any { var policys []Policy - ctl.Policys.Range(func(key, val interface{}) bool { + ctl.Policys.Range(func(key, val any) bool { policys = append(policys, *val.(*Policy)) return true }) @@ -152,13 +155,13 @@ func (ctl *Policys) HandleRuntime(ctx eudore.Context) interface{} { }) members := make(map[int][]*Member) - ctl.Members.Range(func(key, val interface{}) bool { + ctl.Members.Range(func(key, val any) bool { members[key.(int)] = val.([]*Member) return true }) return struct { - Policys []Policy `json:"policys"` - Members interface{} `json:"members"` + Policys []Policy `json:"policys"` + Members any `json:"members"` }{policys, members} } @@ -179,7 +182,7 @@ type forbiddenMessage struct { func (ctl *Policys) handleForbidden(ctx eudore.Context, action, resource, err string) { msg := forbiddenMessage{ - Time: time.Now().Format(eudore.DefaultLoggerTimeFormat), + Time: time.Now().Format(eudore.DefaultLoggerFormatterFormatTime), Host: ctx.Host(), Method: ctx.Method(), Path: ctx.Path(), @@ -194,6 +197,7 @@ func (ctl *Policys) handleForbidden(ctx eudore.Context, action, resource, err st } if ctx.GetParam(eudore.ParamUserid) == "0" { msg.Status = 401 + msg.Error = "" msg.Message = "unauthorized" } ctx.WriteHeader(msg.Status) @@ -206,33 +210,40 @@ const stringBearer = "Bearer " // SignatureUser 定义默认的用户信息,也可以组合该对象使用自定义签名对象。 type SignatureUser struct { // 唯一必要的属性,指定请求的userid - UserID int `json:"userid" alias:"userid"` + UserID int `alias:"user_id" json:"user_id" yaml:"user_id" protobuf:"name=user_id"` + UserName string `alias:"user_name" json:"user_name" yaml:"user_name" protobuf:"name=user_name"` // 如果非空,则为base64([]Statement) Policy string `json:"policy,omitempty" alias:"policy"` Expiration int64 `json:"expiration" alias:"expiration"` } // NewBearer 默认的Bearer签名方法。 -func (ctl *Policys) NewBearer(userid int, policy string, expires int64) string { +func (ctl *Policys) NewBearer(userid int, username, policy string, expires int64) string { return stringBearer + ctl.Signaturer.Signed(&SignatureUser{ UserID: userid, + UserName: username, Policy: base64.StdEncoding.EncodeToString([]byte(policy)), Expiration: expires, }) } -func (ctl *Policys) parseSignatureUser(ctx eudore.Context) (int, error) { +func getBearer(ctx eudore.Context) string { bearer := ctx.GetHeader(eudore.HeaderAuthorization) - if bearer == "" { - return 0, nil + if strings.HasPrefix(bearer, stringBearer) { + return strings.TrimPrefix(bearer, stringBearer) } - if !strings.HasPrefix(bearer, stringBearer) { + return ctx.Request().URL.Query().Get("bearer") +} + +func (ctl *Policys) parseSignatureUser(ctx eudore.Context) (int, error) { + bearer := getBearer(ctx) + if bearer == "" { return 0, nil } // 验证bearer var user SignatureUser - err := ctl.Signaturer.Parse(bearer[7:], &user) + err := ctl.Signaturer.Parse(bearer, &user) if err != nil { return 0, fmt.Errorf("bearer parse error: %s", err.Error()) } @@ -258,6 +269,7 @@ func (ctl *Policys) parseSignatureUser(ctx eudore.Context) (int, error) { for _, s := range statements { if s.MatchAction(action) && s.MatchResource(resource) && s.MatchCondition(ctx) { if s.Effect { + ctx.SetParam(eudore.ParamUsername, user.UserName) return user.UserID, nil } return 0, nil @@ -265,6 +277,7 @@ func (ctl *Policys) parseSignatureUser(ctx eudore.Context) (int, error) { } return 0, nil } + ctx.SetParam(eudore.ParamUsername, user.UserName) return user.UserID, nil } diff --git a/policy/policy.go b/policy/policy.go index 736a771..ca87207 100644 --- a/policy/policy.go +++ b/policy/policy.go @@ -2,6 +2,7 @@ package policy import ( "encoding/json" + "errors" "fmt" "net" "strings" @@ -12,6 +13,10 @@ import ( ) var ( + BearerPrefix = `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.` + ErrVerifyTokenInvalid = errors.New("error: incorrect of results from token parsing") + ErrVerifyResultInvalid = errors.New("error:jwt validation error") + // ErrFormatPolcyUnmarshalError 定义策略json解析错误。 ErrFormatPolcyUnmarshalError = "policy unmarshal json error: %v" // ErrFormatDataParseError 定义策略数据解析错误。 @@ -24,9 +29,10 @@ var ( ErrFormatConditionParseError = "policy conditions %s parse %s error: %v" conditionObjects = make(map[string]func() Condition) - dataObjects = make(map[string]func() interface{}) + dataObjects = make(map[string]func() any) ) +//nolint:gochecknoinits func init() { conditionObjects = map[string]func() Condition{ "and": func() Condition { return &conditionAnd{} }, @@ -35,10 +41,11 @@ func init() { "date": func() Condition { return &conditionDate{} }, "time": func() Condition { return &conditionTime{} }, "method": func() Condition { return &conditionMethod{} }, + "path": func() Condition { return &conditionPath{} }, "params": func() Condition { return &conditionParams{} }, } - dataObjects = map[string]func() interface{}{ - "menu": func() interface{} { return new(string) }, + dataObjects = map[string]func() any{ + "menu": func() any { return new(string) }, } } @@ -60,8 +67,8 @@ type Statement struct { Data map[string][]json.RawMessage `json:"data,omitempty"` treeAction *starTree treeResource *starTree - conditions Condition `json:"-"` - data map[string][]interface{} `json:"-"` + conditions Condition `json:"-"` + data map[string][]any `json:"-"` } type _statement Statement @@ -96,7 +103,7 @@ func (stmt Statement) MatchCondition(ctx eudore.Context) bool { } // MatchData 方法返回匹配时的权限数据。 -func (stmt Statement) MatchData() map[string][]interface{} { +func (stmt Statement) MatchData() map[string][]any { return stmt.data } @@ -160,11 +167,16 @@ type conditionTime struct { type conditionMethod struct { Methods []string `json:"methods"` } + +// conditionMethod 定义请求路径条件。 +type conditionPath struct { + Paths []string `json:"paths"` +} type conditionParams map[string][]string // NewConditions 方法解析多个策略条件。 func NewConditions(data map[string]json.RawMessage) ([]Condition, error) { - var conds []Condition + conds := make([]Condition, 0, len(data)) for key, val := range data { fn, ok := conditionObjects[key] if !ok { @@ -181,8 +193,8 @@ func NewConditions(data map[string]json.RawMessage) ([]Condition, error) { return conds, nil } -func newDatas(body map[string][]json.RawMessage) (map[string][]interface{}, error) { - datas := make(map[string][]interface{}) +func newDatas(body map[string][]json.RawMessage) (map[string][]any, error) { + datas := make(map[string][]any) for key, vals := range body { for _, val := range vals { fn, ok := dataObjects[key] @@ -201,7 +213,7 @@ func newDatas(body map[string][]json.RawMessage) (map[string][]interface{}, erro return datas, nil } -// Match conditionAnd +// Match conditionAnd。 func (cond conditionAnd) Match(ctx eudore.Context) bool { for _, i := range cond.Conditions { if !i.Match(ctx) { @@ -210,6 +222,7 @@ func (cond conditionAnd) Match(ctx eudore.Context) bool { } return true } + func (cond *conditionAnd) UnmarshalJSON(body []byte) error { err := json.Unmarshal(body, &cond.Data) if err != nil { @@ -229,6 +242,7 @@ func (cond conditionOr) Match(ctx eudore.Context) bool { } return false } + func (cond *conditionOr) UnmarshalJSON(body []byte) error { err := json.Unmarshal(body, &cond.Data) if err != nil { @@ -248,13 +262,15 @@ func (cond conditionSourceIP) Match(ctx eudore.Context) bool { } return false } + func (cond *conditionSourceIP) UnmarshalJSON(body []byte) error { var strs []string err := json.Unmarshal(body, &strs) if err != nil { return fmt.Errorf(ErrFormatConditionsUnmarshalError, "sourceip", err) } - var ipnets []*net.IPNet + + ipnets := make([]*net.IPNet, 0, len(strs)) for _, i := range strs { if strings.IndexByte(i, '/') == -1 { i += "/32" @@ -270,10 +286,11 @@ func (cond *conditionSourceIP) UnmarshalJSON(body []byte) error { } // Match 方法匹配当前时间范围。 -func (cond conditionDate) Match(ctx eudore.Context) bool { +func (cond conditionDate) Match(eudore.Context) bool { current := time.Now() return current.Before(cond.Before) && current.After(cond.After) } + func (cond *conditionDate) UnmarshalJSON(body []byte) error { var date _conditionDate err := json.Unmarshal(body, &date) @@ -292,11 +309,12 @@ func (cond *conditionDate) UnmarshalJSON(body []byte) error { } // Match 方法匹配当前时间范围。 -func (cond conditionTime) Match(ctx eudore.Context) bool { +func (cond conditionTime) Match(eudore.Context) bool { current := time.Now() current = time.Date(0, 0, 0, current.Hour(), current.Minute(), current.Second(), 0, current.Location()) return current.Before(cond.Before) && current.After(cond.After) } + func (cond *conditionTime) UnmarshalJSON(body []byte) error { var date _conditionDate err := json.Unmarshal(body, &date) @@ -324,10 +342,26 @@ func (cond conditionMethod) Match(ctx eudore.Context) bool { } return false } + func (cond *conditionMethod) UnmarshalJSON(body []byte) error { return json.Unmarshal(body, &cond.Methods) } +// Match 方法匹配http请求路径。 +func (cond conditionPath) Match(ctx eudore.Context) bool { + path := ctx.Path() + for _, i := range cond.Paths { + if i == path { + return true + } + } + return false +} + +func (cond *conditionPath) UnmarshalJSON(body []byte) error { + return json.Unmarshal(body, &cond.Paths) +} + // Match 方法匹配http请求方法。 func (cond conditionParams) Match(ctx eudore.Context) bool { for key, vals := range cond { diff --git a/policy/util.go b/policy/util.go index c6b5bcc..02f606e 100644 --- a/policy/util.go +++ b/policy/util.go @@ -5,10 +5,8 @@ import ( "crypto/sha256" "encoding/base64" "encoding/json" - "errors" "fmt" "strings" - // "github.com/kr/pretty" ) func stringSliceNotIn(strs []string, str string) bool { @@ -20,7 +18,7 @@ func stringSliceNotIn(strs []string, str string) bool { return true } -// ControllerAction 定义生成action参考控制器 +// ControllerAction 定义生成action参考控制器。 type ControllerAction struct{} // ControllerParam 方法定义ControllerAction生成action参数。 @@ -29,14 +27,12 @@ func (ControllerAction) ControllerParam(pkg, name, method string) string { if pos != 0 { pkg = pkg[pos:] } - if strings.HasSuffix(name, "Controller") { - name = name[:len(name)-len("Controller")] - } + name = strings.TrimSuffix(name, "Controller") return fmt.Sprintf("action=%s:%s:%s", pkg, name, method) } -// NewSignaturerJwt 函数创建一个Jwt Signaturer +// NewSignaturerJwt 函数创建一个Jwt Signaturer。 func NewSignaturerJwt(secret []byte) Signaturer { return verifyFunc(func(b []byte) string { h := hmac.New(sha256.New, secret) @@ -47,20 +43,24 @@ func NewSignaturerJwt(secret []byte) Signaturer { type verifyFunc func([]byte) string -func (fn verifyFunc) Signed(claims interface{}) string { - payload, _ := json.Marshal(claims) - var unsigned string = `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.` + base64.RawURLEncoding.EncodeToString(payload) +func (fn verifyFunc) Signed(claims any) string { + payload, err := json.Marshal(claims) + if err != nil { + return err.Error() + } + + unsigned := BearerPrefix + base64.RawURLEncoding.EncodeToString(payload) return fmt.Sprintf("%s.%s", unsigned, fn([]byte(unsigned))) } -func (fn verifyFunc) Parse(token string, dst interface{}) error { +func (fn verifyFunc) Parse(token string, dst any) error { parts := strings.Split(token, ".") if len(parts) != 3 { - return errors.New("Error: incorrect # of results from string parsing.") + return ErrVerifyTokenInvalid } if fn([]byte(parts[0]+"."+parts[1])) != parts[2] { - return errors.New("Error:jwt validation error.") + return ErrVerifyResultInvalid } payload, err := base64.RawURLEncoding.DecodeString(parts[1]) @@ -68,18 +68,12 @@ func (fn verifyFunc) Parse(token string, dst interface{}) error { return err } - err = json.Unmarshal(payload, dst) - if err != nil { - return err - } - return nil + return json.Unmarshal(payload, dst) } type starTree struct { - // Index int - Name string - Path string - // Const string + Name string + Path string children []*starTree wildcard *starTree } @@ -98,7 +92,7 @@ func (tree *starTree) Insert(path string) { if pos == -1 { break } - path = strings.Replace(path, "**", "*", -1) + path = strings.ReplaceAll(path, "**", "*") } for i, s := range strings.Split(path, "*") { if i != 0 { @@ -158,10 +152,8 @@ func (tree *starTree) Match(path string) string { break } } - } else { - if tree.Name != "" { - return tree.Name - } + } else if tree.Name != "" { + return tree.Name } if tree.wildcard == nil { diff --git a/router.go b/router.go index 3f443cb..3f94ad3 100644 --- a/router.go +++ b/router.go @@ -1,6 +1,6 @@ package eudore -// Router对象用于定义请求的路由器 +// Router对象用于定义请求的路由器。 import ( "context" @@ -11,6 +11,16 @@ import ( "sync" ) +const ( + routerLoggerAll = "all" + routerLoggerHandler = "handler" + routerLoggerController = "controller" + routerLoggerMiddleware = "middleware" + routerLoggerExtend = "extend" + routerLoggerError = "error" + routerLoggerMetadata = "metadata" +) + /* Router interface is divided into RouterCore and RouterMethod. RouterCore implements router matching algorithm and logic, and RouterMethod implements the encapsulation of routing rule registration. @@ -23,24 +33,26 @@ function extensions, controllers and other behaviors. Do not use the RouterCore method to register routes directly at any time. You should use the Add ... method of RouterMethod. RouterMethod implements the following functions: - Group routing - The middleware or function extension is registered in the local scope/global scope - Add controller - Display routing registration debug information + + Group routing + The middleware or function extension is registered in the local scope/global scope + Add controller + Display routing registration debug information RouterCore has four router cores to implement the following functions: - High performance (70%-90% of httprouter performance, using less memory) - Low code complexity (RouterCoreStd supports 5 levels of priority, a code complexity of 19 is not satisfied) - Request for additional default parameters (including current routing matching rules) - Extend custom routing methods - Variable and wildcard matching - Matching priority Constant > Variable verification > Variable > Wildcard verification > Wildcard - Method priority Specify method > Any method (The specified method will override the Any method, and vice versa) - Variables and wildcards support regular and custom functions to verify data - Variables and wildcards support constant prefix - Get all registered routing rule information (RouterCoreBebug implementation) - Routing rule matching based on Host (implemented by RouterCoreHost) - Allows dynamic addition and deletion of router rules at runtime (RouterCoreStd implementation) + + High performance (70%-90% of httprouter performance, using less memory) + Low code complexity (RouterCoreStd supports 5 levels of priority, a code complexity of 19 is not satisfied) + Request for additional default parameters (including current routing matching rules) + Extend custom routing methods + Variable and wildcard matching + Matching priority Constant > Variable verification > Variable > Wildcard verification > Wildcard + Method priority Specify method > Any method (The specified method will override the Any method, and vice versa) + Variables and wildcards support regular and custom functions to verify data + Variables and wildcards support constant prefix + Get all registered routing rule information + Routing rule matching based on Host (implemented by RouterCoreHost) + Allows dynamic addition and deletion of router rules at runtime (RouterCoreStd implementation) Router 接口分为RouterCore和RouterMethod,RouterCore实现路由器匹配算法和逻辑,RouterMethod实现路由规则注册的封装。 @@ -51,41 +63,43 @@ RouterMethod 路由默认直接注册的接口,设置路由参数、组路由 任何时候请不要使用RouterCore的方法直接注册路由,应该使用RouterMethod的Add...方法。 RouterMethod实现下列功能: - 组路由 - 中间件或函数扩展注册在局部作用域/全局作用域 - 添加控制器 - 显示路由注册debug信息 + + 组路由 + 中间件或函数扩展注册在局部作用域/全局作用域 + 添加控制器 + 显示路由注册debug信息 RouterCore拥有四种路由器核心实现下列功能: - 高性能(httprouter性能的70%-90%,使用更少的内存) - 低代码复杂度(RouterCoreStd支持5级优先级 一处代码复杂度19不满足) - 请求获取额外的默认参数(包含当前路由匹配规则) - 扩展自定义路由方法 - 变量和通配符匹配 - 匹配优先级 常量 > 变量校验 > 变量 > 通配符校验 > 通配符(RouterCoreStd五级优先级) - 方法优先级 指定方法 > Any方法(指定方法会覆盖Any方法,反之不行) - 变量和通配符支持正则和自定义函数进行校验数据 - 变量和通配符支持常量前缀 - 获取注册的全部路由规则信息(RouterCoreBebug实现) - 基于Host进行路由规则匹配(RouterCoreHost实现) - 允许运行时进行动态增删路由器规则(RouterCoreStd实现,外层需要RouterCoreLock包装一层) + + 高性能(httprouter性能的70%-90%,使用更少的内存) + 低代码复杂度(RouterCoreStd支持5级优先级 一处代码复杂度19不满足) + 请求获取额外的默认参数(包含当前路由匹配规则) + 扩展自定义路由方法 + 变量和通配符匹配 + 匹配优先级 常量 > 变量校验 > 变量 > 通配符校验 > 通配符(RouterCoreStd五级优先级) + 方法优先级 指定方法 > Any方法(指定方法会覆盖Any方法,反之不行) + 变量和通配符支持正则和自定义函数进行校验数据 + 变量和通配符支持常量前缀 + 获取注册的全部路由规则信息 + 基于Host进行路由规则匹配(RouterCoreHost实现) + 允许运行时进行动态增删路由器规则(RouterCoreStd实现,外层需要RouterCoreLock包装一层) */ type Router interface { RouterCore // RouterMethod method Group(string) Router Params() *Params - AddHandler(string, string, ...interface{}) error + AddHandler(string, string, ...any) error AddController(...Controller) error - AddMiddleware(...interface{}) error - AddHandlerExtend(...interface{}) error - AnyFunc(string, ...interface{}) - GetFunc(string, ...interface{}) - PostFunc(string, ...interface{}) - PutFunc(string, ...interface{}) - DeleteFunc(string, ...interface{}) - HeadFunc(string, ...interface{}) - PatchFunc(string, ...interface{}) + AddMiddleware(...any) error + AddHandlerExtend(...any) error + AnyFunc(string, ...any) + GetFunc(string, ...any) + PostFunc(string, ...any) + PutFunc(string, ...any) + DeleteFunc(string, ...any) + HeadFunc(string, ...any) + PatchFunc(string, ...any) } // The RouterCore interface performs registration of the route and matches a request and returns the handler. @@ -113,70 +127,76 @@ type RouterStd struct { RouterCore `alias:"routercore"` HandlerExtender `alias:"handlerextender"` Middlewares *middlewareTree `alias:"middlewares"` - logger Logger `alias:"logger"` - params Params `alias:"params"` + GroupParams Params `alias:"params"` + Logger Logger `alias:"logger"` + LoggerKind string `alias:"loggerkind"` + Meta *MetadataRouter `alias:"meta"` } -// HandlerRouter405 function defines the default 405 processing and returns Allow and X-Match-Route Header. -// -// HandlerRouter405 函数定义默认405处理,返回Allow和X-Match-Route Header。 -func HandlerRouter405(ctx Context) { - const page405 string = "405 method not allowed" - ctx.SetHeader(HeaderAllow, ctx.GetParam(ParamAllow)) - ctx.SetHeader(HeaderXEudoreRoute, ctx.GetParam(ParamRoute)) - ctx.WriteHeader(405) - ctx.Render(page405) -} - -// HandlerRouter404 function defines the default 404 processing. -// -// HandlerRouter404 函数定义默认404处理。 -func HandlerRouter404(ctx Context) { - const page404 string = "404 page not found" - ctx.WriteHeader(404) - ctx.Render(page404) +type MetadataRouter struct { + Health bool `alias:"health" json:"health" xml:"health" yaml:"health"` + Name string `alias:"name" json:"name" xml:"name" yaml:"name"` + Core any `alias:"core" json:"core" xml:"core" yaml:"core"` + Errors []string `alias:"errors,omitempty" json:"errors,omitempty" xml:"errors,omitempty" yaml:"errors,omitempty"` + Methods []string `alias:"methods" json:"methods" xml:"methods" yaml:"methods"` + Paths []string `alias:"paths" json:"paths" xml:"paths" yaml:"paths"` + Params []Params `alias:"params" json:"params" xml:"params" yaml:"params"` + HandlerNames [][]string `alias:"handlernames" json:"handlernames" xml:"handlernames" yaml:"handlernames"` } -// NewRouterStd method uses a RouterCore to create a Router object. +// NewRouter method uses a RouterCore to create a Router object. // // RouterStd implements RouterMethod interface registration related details, and routing matching is implemented by RouterCore. // -// NewRouterStd 方法使用一个RouterCore创建Router对象。 +// NewRouter 方法使用一个RouterCore创建Router对象。 // -// RouterStd实现RouterMethod接口注册相关细节,路由匹配由RouterCore实现。 -func NewRouterStd(core RouterCore) Router { +// Router实现RouterMethod接口注册相关细节,路由匹配由RouterCore实现。 +func NewRouter(core RouterCore) Router { if core == nil { core = NewRouterCoreStd() } return &RouterStd{ RouterCore: core, - params: Params{ParamRoute, ""}, - HandlerExtender: NewHandlerExtendWarp(NewHandlerExtendTree(), DefaultHandlerExtend), + HandlerExtender: NewHandlerExtenderWarp(NewHandlerExtenderTree(), DefaultHandlerExtender), Middlewares: newMiddlewareTree(), - logger: DefaultLoggerNull, + GroupParams: Params{ParamRoute, ""}, + Logger: DefaultLoggerNull, + LoggerKind: DefaultRouterLoggerKind, + Meta: &MetadataRouter{Name: "eudore.RouterStd"}, } } // Mount 方法使RouterStd挂载上下文,上下文传递给RouterCore。 // -// 并从ctx.Value(ContextKeyApp)获取Logger,初始化RouterStd日志输出函数。 +// 从ctx.Value(ContextKeyApp)获取Logger,初始化RouterStd日志输出函数。 +// +// 从ctx.Value(ContextKeyHandlerExtender)获取HandlerExtender,替换DefaultHandlerExtender。 func (r *RouterStd) Mount(ctx context.Context) { log, ok := ctx.Value(ContextKeyApp).(Logger) if ok { - r.logger = log + r.Logger = log + } + he, ok := ctx.Value(ContextKeyHandlerExtender).(HandlerExtender) + if ok { + r.HandlerExtender = NewHandlerExtenderWarp(NewHandlerExtenderTree(), he) } - withMount(ctx, r.RouterCore) + anyMount(ctx, r.RouterCore) } // Unmount 方法使RouterStd卸载上下文,上下文传递给RouterCore。 func (r *RouterStd) Unmount(ctx context.Context) { - withUnmount(ctx, r.RouterCore) - r.logger = DefaultLoggerNull + anyUnmount(ctx, r.RouterCore) + r.Logger = DefaultLoggerNull } // Metadata 方法返回RouterCore的Metadata。 -func (r *RouterStd) Metadata() interface{} { - return withMetadata(r.RouterCore) +func (r *RouterStd) Metadata() any { + r.Meta.Health = len(r.Meta.Errors) == 0 + r.Meta.Core = anyMetadata(r.RouterCore) + if r.Meta.Core == nil { + r.Meta.Core = fmt.Sprintf("%T", r.RouterCore) + } + return *r.Meta } // Group method returns a new group router. @@ -204,13 +224,23 @@ func (r *RouterStd) Metadata() interface{} { // 最顶级HandlerExtender对象为defaultHandlerExtend, // 可以使用RegisterHandlerExtend函数和NewHandlerFuncs函数调用defaultHandlerExtend对象。 func (r *RouterStd) Group(path string) Router { + params := NewParamsRoute(path) + kind := params.Get(ParamLoggerKind) + if kind != "" { + params.Del(ParamLoggerKind) + } else { + kind = r.LoggerKind + } + // 构建新的路由方法配置器 return &RouterStd{ RouterCore: r.RouterCore, - HandlerExtender: NewHandlerExtendWarp(NewHandlerExtendTree(), r.HandlerExtender), + HandlerExtender: NewHandlerExtenderWarp(NewHandlerExtenderTree(), r.HandlerExtender), Middlewares: r.Middlewares.clone(), - logger: r.logger, - params: r.params.Clone().CombineWithRoute(NewParamsRoute(path)), + Logger: r.Logger, + LoggerKind: kind, + GroupParams: r.GroupParams.Clone().CombineWithRoute(params), + Meta: r.Meta, } } @@ -218,13 +248,12 @@ func (r *RouterStd) Group(path string) Router { // // Params 方法返回当前路由参数,路由参数值为空字符串不会被使用。 func (r *RouterStd) Params() *Params { - return &r.params + return &r.GroupParams } // getRoutePath 函数截取到路径中的route,支持'{}'进行块匹配。 func getRoutePath(path string) string { - var depth = 0 - var str = "" + depth, str := 0, "" for i := range path { switch path[i] { case '{': @@ -285,81 +314,77 @@ func getRouteParam(path, key string) string { // 如果当前Router无法处理,则调用上一级group的HandlerExtender或defaultHandlerExtend处理,全部无法处理则输出error日志。 // // 中间件数据会根据当前路由路径从数据中匹配,然后将请求处理函数附加到处理函数之前。 -// -func (r *RouterStd) AddHandler(method, path string, hs ...interface{}) error { - return r.registerHandlers(method, path, hs...) +func (r *RouterStd) AddHandler(method, path string, hs ...any) error { + return r.addHandler(strings.ToUpper(method), path, hs...) } -// registerHandlers 方法将handler转换成HandlerFuncs,添加路由路径对应的请求中间件,并调用RouterCore对象注册路由方法。 -func (r *RouterStd) registerHandlers(method, path string, hs ...interface{}) (err error) { +// addHandler 方法将handler转换成HandlerFuncs,添加路由路径对应的请求中间件,并调用RouterCore对象注册路由方法。 +func (r *RouterStd) addHandler(method, path string, hs ...any) (err error) { defer func() { // RouterCoreStd 注册未知校验规则存在panic,或者其他自定义路由注册出现panic。 if rerr := recover(); rerr != nil { - err = fmt.Errorf(ErrFormatRouterStdRegisterHandlersRecover, method, path, rerr) - r.logger.WithField("depth", "stack").WithField("params", r.params).Error(err) + err = fmt.Errorf(ErrFormatRouterStdAddHandlerRecover, method, path, rerr) + r.getLoggerError(err, 0).WithField("depth", "stack").Error(err) } }() - params := r.params.Clone().CombineWithRoute(NewParamsRoute(path)) + depth := getDepthWithFunc(2, 8, ".AddController") + params := r.GroupParams.Clone().CombineWithRoute(NewParamsRoute(path)) path = params.Get("route") fullpath := params.String() // 如果方法为404、405方法,route为空 if len(fullpath) > 6 && fullpath[:6] == "route=" { fullpath = fullpath[6:] } - method = strings.ToUpper(method) - handlers, err := r.newHandlerFuncs(path, hs) + handlers, err := r.newHandlerFuncs(path, hs, depth+1) if err != nil { - r.logger.WithField("depth", getContrllerDepth()).WithField("params", r.params).Error(err) return err } + // 如果注册方法是TEST则输出RouterStd debug信息 if method == "TEST" { - r.logger.WithField("depth", getContrllerDepth()).Debugf( + r.getLogger(routerLoggerHandler, depth).Debugf( "Test handlers params is %s, split path to: ['%s'], match middlewares is: %v, register handlers is: %v.", params.String(), strings.Join(getSplitPath(path), "', '"), r.Middlewares.Lookup(path), handlers, ) - return + return nil } - r.logger.WithField("depth", getContrllerDepth()).Info("Register handler:", + r.getLogger(routerLoggerHandler, depth).Info("Register handler:", method, strings.TrimPrefix(params.String(), "route="), handlers) if handlers != nil { handlers = NewHandlerFuncsCombine(r.Middlewares.Lookup(path), handlers) } // 处理多方法 - var errs errormulit + var errs mulitError for _, method := range strings.Split(method, ",") { method = strings.TrimSpace(method) if checkMethod(method) { r.RouterCore.HandleFunc(method, fullpath, handlers) + if r.getLogger(routerLoggerMetadata, 0) != DefaultLoggerNull { + r.Meta.addHandler(method, path, handlers) + } } else { - err := fmt.Errorf(ErrFormatRouterStdRegisterHandlersMethodInvalid, method, method, fullpath) + err := fmt.Errorf(ErrFormatRouterStdAddHandlerMethodInvalid, method, fullpath) errs.HandleError(err) - r.logger.WithField("depth", getContrllerDepth()).WithField("params", r.params).Error(err) + r.getLoggerError(err, depth).Error(err) } } return errs.Unwrap() } -func getContrllerDepth() int { - pc := make([]uintptr, 10) - n := runtime.Callers(2, pc) - if n > 0 { - index := 1 - frames := runtime.CallersFrames(pc[:n]) - frame, more := frames.Next() - for more { - if strings.HasSuffix(frame.Function, ".AddController") { - return index - } - - index++ - frame, more = frames.Next() +func checkMethod(method string) bool { + switch method { + case "ANY", "404", "405", "NotFound", "MethodNotAllowed": + return true + } + for _, allMethod := range DefaultRouterAllMethod { + if allMethod == method { + return true } } - return 2 + return false } // The newHandlerFuncs method creates HandlerFuncs based on the path and multiple parameters. @@ -370,57 +395,47 @@ func getContrllerDepth() int { // newHandlerFuncs 方法根据路径和多个参数创建HandlerFuncs。 // // RouterStd先调用当前HandlerExtender.NewHandlerFuncs创建多个函数处理者,如果返回空会从上级HandlerExtender创建。 -func (r *RouterStd) newHandlerFuncs(path string, handlers []interface{}) (HandlerFuncs, error) { +func (r *RouterStd) newHandlerFuncs(path string, handlers []any, depth int) (HandlerFuncs, error) { var hs HandlerFuncs - var errs errormulit + var errs mulitError // 转换处理函数 for i, fn := range handlers { - handler := r.HandlerExtender.NewHandlerFuncs(path, fn) - if handler != nil && len(handler) > 0 { + handler := r.HandlerExtender.CreateHandler(path, fn) + if len(handler) > 0 { hs = NewHandlerFuncsCombine(hs, handler) } else { - errs.HandleError(fmt.Errorf(ErrFormatRouterStdNewHandlerFuncsUnregisterType, path, i, reflect.TypeOf(fn).String())) + err := fmt.Errorf(ErrFormatRouterStdNewHandlerFuncsUnregisterType, path, i, reflect.TypeOf(fn).String()) + errs.HandleError(err) + r.getLoggerError(err, depth).Error(err) } } return hs, errs.Unwrap() } -func checkMethod(method string) bool { - switch method { - case "ANY", "404", "405", "NotFound", "MethodNotAllowed": - return true - } - for _, allMethod := range DefaultRouterAllMethod { - if allMethod == method { - return true - } - } - return false -} - // AddController method registers the controller, and the controller determines the routing registration behavior. // // AddController 方法注册控制器,由控制器决定路由注册行为。 func (r *RouterStd) AddController(controllers ...Controller) error { - var errs errormulit + var errs mulitError for _, controller := range controllers { + route := strings.TrimPrefix(r.GroupParams.String(), "route=") name := getControllerPathName(controller) - r.logger.WithField("depth", 1).Info("Register controller:", r.params.String(), name) + r.getLogger(routerLoggerController, 1).Info("Register controller:", route, name) err := controller.Inject(controller, r) if err != nil { err = fmt.Errorf(ErrFormatRouterStdAddController, name, err) errs.HandleError(err) - r.logger.WithField("depth", 1).WithField("params", r.params).Error(err) + r.getLoggerError(err, 1).Error(err) } } return errs.Unwrap() } -// getControllerPathName 函数获取控制器的名称 +// getControllerPathName 函数获取控制器的名称。 func getControllerPathName(ctl Controller) string { - ster, ok := ctl.(controllerName) + u, ok := ctl.(interface{ Unwrap() Controller }) if ok { - return ster.ControllerName() + ctl = u.Unwrap() } cType := reflect.Indirect(reflect.ValueOf(ctl)).Type() return fmt.Sprintf("%s.%s", cType.PkgPath(), cType.Name()) @@ -434,40 +449,32 @@ func getControllerPathName(ctl Controller) string { // AddMiddleware 给路由器添加多个中间件函数,会使用HandlerExtender转换参数。 // // 如果参数数量大于1且第一个参数为字符串类型,会将第一个字符串类型参数作为添加中间件的路径。 -func (r *RouterStd) AddMiddleware(hs ...interface{}) error { +func (r *RouterStd) AddMiddleware(hs ...any) error { if len(hs) == 0 { return nil } - path := r.params.Get("route") + depth := getDepthWithFunc(1, 4, "(*App).AddMiddleware") + path := r.GroupParams.Get("route") if len(hs) > 1 { route, ok := hs[0].(string) if ok { - path = path + route + path += route hs = hs[1:] } } - handlers, err := r.newHandlerFuncs(path, hs) + handlers, err := r.newHandlerFuncs(path, hs, depth+1) if err != nil { - r.logger.WithField("depth", getMiddlewareDepath()).WithField("params", r.params).Error(err) return err } r.Middlewares.Insert(path, handlers) r.RouterCore.HandleFunc("Middlewares", path, handlers) - r.logger.WithField("depth", getMiddlewareDepath()).Info("Register middleware:", path, handlers) + r.getLogger(routerLoggerMiddleware, depth).Info("Register middleware:", path, handlers) return nil } -func getMiddlewareDepath() int { - ptr, _, _, ok := runtime.Caller(2) - if ok && strings.HasSuffix(runtime.FuncForPC(ptr).Name(), ".(*App).AddMiddleware") { - return 2 - } - return 1 -} - // AddHandlerExtend method adds an extension function to the current Router. // // If the number of parameters is greater than 1 and the first parameter is a string type, @@ -476,32 +483,32 @@ func getMiddlewareDepath() int { // AddHandlerExtend 方法给当前Router添加扩展函数。 // // 如果参数数量大于1且第一个参数为字符串类型,会将第一个字符串类型参数作为添加扩展函数的路径。 -func (r *RouterStd) AddHandlerExtend(handlers ...interface{}) error { +func (r *RouterStd) AddHandlerExtend(handlers ...any) error { if len(handlers) == 0 { return nil } - path := r.params.Get("route") + path := r.GroupParams.Get("route") if len(handlers) > 1 { route, ok := handlers[0].(string) if ok { - path = path + route + path += route handlers = handlers[1:] } } - var errs errormulit + var errs mulitError for _, handler := range handlers { - err := r.HandlerExtender.RegisterHandlerExtend(path, handler) + err := r.HandlerExtender.RegisterExtender(path, handler) if err != nil { - err = fmt.Errorf(ErrFormatRouterStdAddHandlerExtend, path, err) + err = fmt.Errorf(ErrFormatRouterStdAddHandlerExtender, path, err) errs.HandleError(err) - r.logger.WithField("depth", 1).WithField("params", r.params).Error(err) + r.getLoggerError(err, 1).Error(err) } else { - iValue := reflect.ValueOf(handler) - if iValue.Kind() == reflect.Func { - r.logger.WithField("depth", 1).Info("Register extend:", - runtime.FuncForPC(iValue.Pointer()).Name(), iValue.Type().In(0).String()) + v := reflect.ValueOf(handler) + if v.Kind() == reflect.Func { + name := runtime.FuncForPC(v.Pointer()).Name() + r.getLogger(routerLoggerExtend, 1).Info("Register extend:", name, v.Type().In(0).String()) } } } @@ -518,48 +525,109 @@ func (r *RouterStd) AddHandlerExtend(handlers ...interface{}) error { // // Any方法注册的路由规则会被指定方法注册覆盖,反之不行。 // Any默认注册方法包含Get Post Put Delete Head Patch六种,定义在全局变量RouterAnyMethod。 -func (r *RouterStd) AnyFunc(path string, h ...interface{}) { - r.registerHandlers(MethodAny, path, h...) +func (r *RouterStd) AnyFunc(path string, h ...any) { + _ = r.addHandler(MethodAny, path, h...) } // GetFunc 方法实现注册一个Get方法的http请求处理函数。 -func (r *RouterStd) GetFunc(path string, h ...interface{}) { - r.registerHandlers(MethodGet, path, h...) +func (r *RouterStd) GetFunc(path string, h ...any) { + _ = r.addHandler(MethodGet, path, h...) } // PostFunc 方法实现注册一个Post方法的http请求处理函数。 -func (r *RouterStd) PostFunc(path string, h ...interface{}) { - r.registerHandlers(MethodPost, path, h...) +func (r *RouterStd) PostFunc(path string, h ...any) { + _ = r.addHandler(MethodPost, path, h...) } // PutFunc 方法实现注册一个Put方法的http请求处理函数。 -func (r *RouterStd) PutFunc(path string, h ...interface{}) { - r.registerHandlers(MethodPut, path, h...) +func (r *RouterStd) PutFunc(path string, h ...any) { + _ = r.addHandler(MethodPut, path, h...) } // DeleteFunc 方法实现注册一个Delete方法的http请求处理函数。 -func (r *RouterStd) DeleteFunc(path string, h ...interface{}) { - r.registerHandlers(MethodDelete, path, h...) +func (r *RouterStd) DeleteFunc(path string, h ...any) { + _ = r.addHandler(MethodDelete, path, h...) } // HeadFunc 方法实现注册一个Head方法的http请求处理函数。 -func (r *RouterStd) HeadFunc(path string, h ...interface{}) { - r.registerHandlers(MethodHead, path, h...) +func (r *RouterStd) HeadFunc(path string, h ...any) { + _ = r.addHandler(MethodHead, path, h...) } // PatchFunc 方法实现注册一个Patch方法的http请求处理函数。 -func (r *RouterStd) PatchFunc(path string, h ...interface{}) { - r.registerHandlers(MethodPatch, path, h...) +func (r *RouterStd) PatchFunc(path string, h ...any) { + _ = r.addHandler(MethodPatch, path, h...) } -// middlewareTree 定义中间件信息存储树 +func (r *RouterStd) getLogger(kind string, depth int) Logger { + if strings.Contains(r.LoggerKind, kind) || strings.Contains(r.LoggerKind, routerLoggerAll) { + if depth > 0 { + return r.Logger.WithField(ParamDepth, depth) + } + return r.Logger + } + return DefaultLoggerNull +} + +func (r *RouterStd) getLoggerError(err error, depth int) Logger { + r.Meta.Errors = append(r.Meta.Errors, err.Error()) + return r.getLogger(routerLoggerError, depth) +} + +func getDepthWithFunc(start, size int, fn string) int { + pc := make([]uintptr, size) + n := runtime.Callers(start+1, pc) + if n > 0 { + index := start + frames := runtime.CallersFrames(pc[:n]) + frame, more := frames.Next() + for more { + if strings.HasSuffix(frame.Function, fn) { + return index + } + + index++ + frame, more = frames.Next() + } + } + return start +} + +// addHandler 方法保持添加的路由信息。 +func (r *MetadataRouter) addHandler(method, path string, handlers HandlerFuncs) { + // 删除记录的路由信息 + if getRouteParam(path, ParamRegister) == "off" || handlers == nil { + path = getRoutePath(path) + for i := range r.Methods { + if r.Paths[i] == path && r.Methods[i] == method { + r.Methods = r.Methods[:i+copy(r.Methods[i:], r.Methods[i+1:])] + r.Paths = r.Paths[:i+copy(r.Paths[i:], r.Paths[i+1:])] + r.Params = r.Params[:i+copy(r.Params[i:], r.Params[i+1:])] + r.HandlerNames = r.HandlerNames[:i+copy(r.HandlerNames[i:], r.HandlerNames[i+1:])] + break + } + } + return + } + + names := make([]string, len(handlers)) + for i := range handlers { + names[i] = fmt.Sprint(handlers[i]) + } + r.Methods = append(r.Methods, method) + r.Paths = append(r.Paths, getRoutePath(path)) + r.Params = append(r.Params, NewParamsRoute(path)) + r.HandlerNames = append(r.HandlerNames, names) +} + +// middlewareTree 定义中间件信息存储树。 type middlewareTree struct { index int node *middlewareNode } func newMiddlewareTree() *middlewareTree { - return &middlewareTree{node: new(middlewareNode)} + return &middlewareTree{node: &middlewareNode{}} } func (t *middlewareTree) Insert(path string, val HandlerFuncs) { @@ -571,7 +639,7 @@ func (t *middlewareTree) Insert(path string, val HandlerFuncs) { t.node.Insert(path, indexs, val) } -// Lookup 方法查找路径对应的处理函数,并安装索引进行排序。 +// Lookup 方法查找路径对应的处理函数,并按照索引进行排序。 func (t *middlewareTree) Lookup(path string) HandlerFuncs { indexs, vals := t.node.Lookup(path) length := len(vals) @@ -627,7 +695,7 @@ func (t *middlewareNode) Insert(path string, indexs []int, vals HandlerFuncs) { t.childs = append(t.childs, &middlewareNode{path: path, indexs: indexs, vals: vals}) } -// Lookup Find if seachKey exist in current trie tree and return its value +// Lookup Find if seachKey exist in current trie tree and return its value. func (t *middlewareNode) Lookup(path string) ([]int, HandlerFuncs) { for _, i := range t.childs { if strings.HasPrefix(path, i.path) { @@ -638,7 +706,7 @@ func (t *middlewareNode) Lookup(path string) ([]int, HandlerFuncs) { return t.indexs, t.vals } -// clone 方法深拷贝这个中间件存储节点 +// clone 方法深拷贝这个中间件存储节点。 func (t *middlewareNode) clone() *middlewareNode { nt := *t for i := range nt.childs { @@ -647,7 +715,7 @@ func (t *middlewareNode) clone() *middlewareNode { return &nt } -// indexsCombine 函数合并两个int切片 +// indexsCombine 函数合并两个int切片。 func indexsCombine(hs1, hs2 []int) []int { // if nil if len(hs1) == 0 { @@ -664,8 +732,8 @@ func indexsCombine(hs1, hs2 []int) []int { // // routerCoreLock 允许对RouterCore读写进行加锁,用于运行时动态增删路由规则。 type routerCoreLock struct { - sync.RWMutex RouterCore + sync.RWMutex } // NewRouterCoreLock function creates a router core with a read-write lock, @@ -690,99 +758,13 @@ func (r *routerCoreLock) HandleFunc(method, path string, hs HandlerFuncs) { } // Match 方法对路由器加读锁进行匹配请求。 -func (r *routerCoreLock) Match(method, path string, params *Params) (hs HandlerFuncs) { +func (r *routerCoreLock) Match(method, path string, params *Params) HandlerFuncs { r.RLock() - hs = r.RouterCore.Match(method, path, params) - r.RUnlock() - return -} - -// routerCoreDebug 定义debug路由器。 -type routerCoreDebug struct { - RouterCore `json:"-" xml:"-"` - Methods []string `json:"methods" xml:"methods"` - Paths []string `json:"paths" xml:"paths"` - Params []Params `json:"params" xml:"params"` - HandlerNames [][]string `json:"handlernames" xml:"handlernames"` -} - -type routerCoreMetadata struct { - Health bool `json:"health" xml:"health"` - Name string `json:"name" xml:"name"` - Methods []string `json:"methods" xml:"methods"` - Paths []string `json:"paths" xml:"paths"` - Params []Params `json:"params" xml:"params"` - HandlerNames [][]string `json:"handlernames" xml:"handlernames"` -} - -// NewRouterCoreDebug function specifies the routing core to create a debug core, -// using eudore.RouterCoreStd as the core by default. -// -// Visit GET /eudore/debug/router/data to get router registration information. -// -// NewRouterCoreDebug 函数指定路由核心创建一个debug核心,默认使用eudore.RouterCoreStd为核心。 -// -// 访问 GET /eudore/debug/router/data 可以获取路由器注册信息。 -func NewRouterCoreDebug(core RouterCore) RouterCore { - if core == nil { - core = NewRouterCoreStd() - } - r := &routerCoreDebug{ - RouterCore: core, - } - r.HandleFunc("GET", "/eudore/debug/router/data", HandlerFuncs{r.HandleHTTP}) - return r -} - -// HandleFunc implements the eudore.RouterCore interface and records all routing information. -// -// HandleFunc 实现eudore.RouterCore接口,记录全部路由信息。 -func (r *routerCoreDebug) HandleFunc(method, path string, handlers HandlerFuncs) { - r.RouterCore.HandleFunc(method, path, handlers) - // 删除记录的路由信息 - if getRouteParam(path, ParamRegister) == "off" || handlers == nil { - path = getRoutePath(path) - for i := range r.Methods { - if r.Paths[i] == path && r.Methods[i] == method { - r.Methods = r.Methods[:i+copy(r.Methods[i:], r.Methods[i+1:])] - r.Paths = r.Paths[:i+copy(r.Paths[i:], r.Paths[i+1:])] - r.Params = r.Params[:i+copy(r.Params[i:], r.Params[i+1:])] - r.HandlerNames = r.HandlerNames[:i+copy(r.HandlerNames[i:], r.HandlerNames[i+1:])] - break - } - } - return - } - - names := make([]string, len(handlers)) - for i := range handlers { - names[i] = fmt.Sprint(handlers[i]) - } - r.Methods = append(r.Methods, method) - r.Paths = append(r.Paths, getRoutePath(path)) - r.Params = append(r.Params, NewParamsRoute(path)) - r.HandlerNames = append(r.HandlerNames, names) -} - -// Metadata 方法返回routerCoreDebug记录的路由信息 -func (r *routerCoreDebug) Metadata() interface{} { - return routerCoreMetadata{ - Health: true, - Name: "eudore.routerCoreDebug", - Methods: r.Methods, - Paths: r.Paths, - Params: r.Params, - HandlerNames: r.HandlerNames, - } -} - -// HandleHTTP 方法返回debug路由信息数据。 -func (r *routerCoreDebug) HandleHTTP(ctx Context) { - ctx.SetHeader(HeaderXEudoreAdmin, "router-debug") - ctx.Render(r.Metadata()) + defer r.RUnlock() // if valid func panic + return r.RouterCore.Match(method, path, params) } -// routerCoreHost 实现基于host进行路由匹配 +// routerCoreHost 实现基于host进行路由匹配。 type routerCoreHost struct { routertree routerHostNode routers map[string]RouterCore @@ -819,7 +801,7 @@ func NewRouterCoreHost(fn func(string) RouterCore) RouterCore { // If the host value is empty and registered to the router core of'*', // multiple hosts are allowed to use',' to divide the registration to multiple hosts at once. // -// HandleFunc 方法从path中寻找host参数选择路由器注册匹配 +// # HandleFunc 方法从path中寻找host参数选择路由器注册匹配 // // host值为一个host模式,允许存在*,表示当前任意字符到下一个'.'或结尾。 // @@ -853,12 +835,13 @@ func (r *routerCoreHost) getRouterCore(host string) RouterCore { } // Match 方法返回routerCoreHost.matchHost函数处理请求,在matchHost函数中使用host值进行二次匹配并拼接请求处理函数。 -func (r *routerCoreHost) Match(method, path string, params *Params) HandlerFuncs { +func (r *routerCoreHost) Match(string, string, *Params) HandlerFuncs { return HandlerFuncs{r.matchHost} } func (r *routerCoreHost) matchHost(ctx Context) { - hs := r.routertree.matchNode(split2byte(ctx.Host(), ':')).Match(ctx.Method(), ctx.Path(), ctx.Params()) + host, port, _ := strings.Cut(ctx.Host(), ":") + hs := r.routertree.matchNode(host, port).Match(ctx.Method(), ctx.Path(), ctx.Params()) index, handlers := ctx.GetHandler() ctx.SetHandler(index, NewHandlerFuncsCombine(NewHandlerFuncsCombine(handlers[:index+1], hs), handlers[index+1:])) } @@ -891,7 +874,7 @@ func (node *routerHostNode) getRouter(port string) RouterCore { } func (node *routerHostNode) insert(path string, val RouterCore) { - host, port := split2byte(path, ':') + host, port, _ := strings.Cut(path, ":") paths := strings.Split(host, "*") newpaths := make([]string, 1, len(paths)*2-1) newpaths[0] = paths[0] diff --git a/routerstd.go b/routerstd.go index 8b12c9a..2a18add 100644 --- a/routerstd.go +++ b/routerstd.go @@ -82,7 +82,7 @@ func (r *routerCoreStd) Mount(ctx context.Context) { // The router matches the handlers available to the current path from // the middleware tree and adds them to the front of the handler. // -// HandleFunc 给路由器注册一个新的方法请求路径 +// HandleFunc 给路由器注册一个新的方法请求路径。 // // 路由器会从中间件树中匹配当前路径可使用的处理者,并添加到处理者前方。 func (r *routerCoreStd) HandleFunc(method string, path string, handler HandlerFuncs) { @@ -116,7 +116,7 @@ func (r *routerCoreStd) Match(method, path string, params *Params) HandlerFuncs return r.handler404 } // default method - for i, m := range defaultRouterAnyMethod { + for i, m := range DefaultRouterCoreMethod { if m == method { if node.handlers[i] != nil { *params = params.Set(ParamRoute, node.params[i][1]).Add(node.params[i][2:]...) @@ -141,7 +141,7 @@ func (r *routerCoreStd) Match(method, path string, params *Params) HandlerFuncs // // 添加一个新的路由Node。 func (r *routerCoreStd) insertRoute(method, path string, val HandlerFuncs) { - var currentNode = r.root + currentNode := r.root params := NewParamsRoute(path) if params.Get(ParamRegister) == "off" || val == nil { @@ -196,7 +196,7 @@ func (r *routerCoreStd) newStdNode(path string) *stdNode { func (r *routerCoreStd) loadCheckFunc(path string) (string, func(string) bool) { path = path[1:] // 截取参数名称和校验函数名称 - name, fname := split2byte(path, '|') + name, fname, _ := strings.Cut(path, "|") if name == "" || fname == "" { return "", nil } @@ -206,11 +206,11 @@ func (r *routerCoreStd) loadCheckFunc(path string) (string, func(string) bool) { } // 调用FuncCreator创建check函数 - fn, err := r.FuncCreator.Create(typeString, fname) + fn, err := r.FuncCreator.CreateFunc(FuncCreateString, fname) if err == nil { return name, fn.(func(string) bool) } - // 无法获得校验函数抛出错误 + // 无法获得校验函数抛出错误,由RouterStd recover。 panic(fmt.Errorf(ErrFormarRouterStdLoadInvalidFunc, path, err)) } @@ -221,7 +221,7 @@ func (r *stdNode) setHandler(method string, params Params, handler HandlerFuncs) } for i := uint(0); i < 6; i++ { - if defaultRouterAnyMethod[i] == method { + if DefaultRouterCoreMethod[i] == method { r.params[i] = params r.handlers[i] = handler r.isany &^= 1 << i @@ -235,22 +235,19 @@ func (r *stdNode) setHandler(method string, params Params, handler HandlerFuncs) } func (r *stdNode) setHandlerAny(params Params, handler HandlerFuncs) { - // 设置标准Any - for i := uint(0); i < 6; i++ { - if r.isany>>i&0x1 == 0x1 || r.handlers[i] == nil { - r.params[i] = params - r.handlers[i] = handler - r.isany |= 1 << i - } - } - // 设置others any for _, method := range DefaultRouterAnyMethod { - i := getStringInIndex(method, defaultRouterAnyMethod) + i := getStringInIndex(method, DefaultRouterCoreMethod) if i == -1 { + // 设置others any if r.others == nil { r.others = make(map[string]stdOtherHandler) } r.others[method] = stdOtherHandler{any: true, params: params, handler: handler} + } else if r.isany>>i&0x1 == 0x1 || r.handlers[i] == nil { + // 设置标准Any + r.params[i] = params + r.handlers[i] = handler + r.isany |= 1 << i } } r.isany |= 0x40 @@ -371,16 +368,15 @@ func (r *stdNode) insertNodeConst(path string, nextNode *stdNode) *stdNode { return nextNode } +//nolint:cyclop,gocyclo func (r *stdNode) lookNode(searchKey string, params *Params) *stdNode { // constant match, return data - // 常量匹配,返回数据 if len(searchKey) == 0 && r.route != "" { return r } if len(searchKey) > 0 { // Traverse constant Node match - // 遍历常量Node匹配,数据量少使用二分查找无效 for _, child := range r.Cchildren { if child.path[0] >= searchKey[0] { length := len(child.path) @@ -394,7 +390,6 @@ func (r *stdNode) lookNode(searchKey string, params *Params) *stdNode { } // parameter matching, Check if there is a parameter match - // 参数匹配 检测是否存在参数匹配 if r.pnum != 0 { pos := strings.IndexByte(searchKey, '/') if pos == -1 { @@ -403,7 +398,6 @@ func (r *stdNode) lookNode(searchKey string, params *Params) *stdNode { currentKey, nextSearchKey := searchKey[:pos], searchKey[pos:] // check parameter matching - // 校验参数匹配 for _, child := range r.PVchildren { if child.check(currentKey) { if n := child.lookNode(nextSearchKey, params); n != nil { @@ -412,9 +406,6 @@ func (r *stdNode) lookNode(searchKey string, params *Params) *stdNode { } } } - - // 参数匹配 - // 变量Node依次匹配是否满足 for _, child := range r.Pchildren { if n := child.lookNode(nextSearchKey, params); n != nil { *params = params.Add(child.name, currentKey) @@ -423,27 +414,19 @@ func (r *stdNode) lookNode(searchKey string, params *Params) *stdNode { } } } - - // wildcard verification match - // If the current Node has a wildcard processing method that directly matches, the result is returned. - // 通配符校验匹配 - // 若当前Node有通配符处理方法直接匹配,返回结果。 + // If the current Node has a wildcard processing method that directly matches for _, child := range r.WVchildren { if child.check(searchKey) { *params = params.Add(child.name, searchKey) return child } } - - // If the current Node has a wildcard processing method that directly matches, the result is returned. - // 若当前Node有通配符处理方法直接匹配,返回结果。 + // If the current Node has a wildcard processing method that directly matches if r.Wchildren != nil { *params = params.Add(r.Wchildren.name, searchKey) return r.Wchildren } - // can't match, return nil - // 无法匹配,返回空 return nil } @@ -639,22 +622,21 @@ func stdRemoveNode(nodes []*stdNode, node *stdNode) []*stdNode { } /* -The string is cut according to the Node type. -将字符串按Node类型切割 -String path cutting example: -字符串路径切割例子: -/ [/] -/api/note/ [/api/note/] -//api/* [/api/ *] -//api/*name [/api/ *name] -/api/get/ [/api/get/] -/api/get [/api/get] -/api/:get [/api/ :get] -/api/:get/* [/api/ :get / *] -/api/:name/info/* [/api/ :name /info/ *] -/api/:name|^\\d+$/info [/api/ :name|^\d+$ /info] -/api/*|{^0/api\\S+$} [/api/ *|{^0/api\S+$}] -/api/*|^\\$\\d+$ [/api/ *|^\$\d+$] +The string is cut according to the Node type, String path cutting example: +将字符串按Node类型切割,字符串路径切割例子: + + / [/] + /api/note/ [/api/note/] + //api/* [//api/ *] + //api/*name [//api/ *name] + /api/get/ [/api/get/] + /api/get [/api/get] + /api/:get [/api/ :get] + /api/:get/* [/api/ :get / *] + /api/:name/info/* [/api/ :name /info/ *] + /api/:name|^\\d+$/info [/api/ :name|^\d+$ /info] + /api/*|{^0/api\\S+$} [/api/ *|^0/api\S+$] + /api/*|^\\$\\d+$ [/api/ *|^\$\d+$] */ func getSplitPath(key string) []string { if len(key) < 2 { @@ -664,9 +646,9 @@ func getSplitPath(key string) []string { key = "/" + key } var strs []string - var length = -1 - var isblock = 0 - var isconst = false + length := -1 + isblock := 0 + isconst := false for i := range key { // 块模式匹配 if isblock > 0 { @@ -677,7 +659,7 @@ func getSplitPath(key string) []string { isblock-- } if isblock > 0 { - strs[length] = strs[length] + key[i:i+1] + strs[length] += key[i : i+1] } continue } @@ -689,22 +671,16 @@ func getSplitPath(key string) []string { strs = append(strs, "") isconst = true } - case ':': - // 变量模式 + case ':', '*': + // 变量模式,通配符模式 isconst = false length++ strs = append(strs, "") - case '*': - // 通配符模式 - isconst = false - length++ - strs = append(strs, key[i:]) - return strs case '{': isblock++ continue } - strs[length] = strs[length] + key[i:i+1] + strs[length] += key[i : i+1] } return strs } diff --git a/server.go b/server.go index 811da2c..91ff761 100644 --- a/server.go +++ b/server.go @@ -8,12 +8,12 @@ import ( "crypto/x509" "crypto/x509/pkix" "errors" - "io/ioutil" "log" "math/big" "net" "net/http" "net/http/fcgi" + "os" "strings" "sync" "sync/atomic" @@ -27,48 +27,47 @@ type Server interface { Shutdown(context.Context) error } -// ServerStdConfig 定义serverStd使用的配置 -type ServerStdConfig struct { +// ServerConfig 定义serverStd使用的配置。 +type ServerConfig struct { // set default ServerHandler - Handler http.Handler + Handler http.Handler `alias:"handler" json:"-" xml:"-" yaml:"-"` // ReadTimeout is the maximum duration for reading the entire request, including the body. // // Because ReadTimeout does not let Handlers make per-request decisions on each request body's acceptable deadline or upload rate, // most users will prefer to use ReadHeaderTimeout. It is valid to use them both. - ReadTimeout TimeDuration `alias:"readtimeout" json:"readtimeout" description:"Http server read timeout."` + ReadTimeout TimeDuration `alias:"readtimeout" json:"readtimeout" xml:"readtimeout" yaml:"readtimeout" description:"Http server read timeout."` // ReadHeaderTimeout is the amount of time allowed to read request headers. // The connection's read deadline is reset after reading the headers and the Handler can decide what is considered too slow for the body. - ReadHeaderTimeout TimeDuration `alias:"readheadertimeout" json:"readheadertimeout" description:"Http server read header timeout."` // Go 1.8 - + ReadHeaderTimeout TimeDuration `alias:"readheadertimeout" json:"readheadertimeout" xml:"readheadertimeout" yaml:"readheadertimeout" description:"Http server read header timeout."` // WriteTimeout is the maximum duration before timing out writes of the response. // It is reset whenever a new request's header is read. // Like ReadTimeout, it does not let Handlers make decisions on a per-request basis. - WriteTimeout TimeDuration `alias:"writetimeout" json:"writetimeout" description:"Http server write timeout."` + WriteTimeout TimeDuration `alias:"writetimeout" json:"writetimeout" xml:"writetimeout" yaml:"writetimeout" description:"Http server write timeout."` // IdleTimeout is the maximum amount of time to wait for the next request when keep-alives are enabled. // If IdleTimeout is zero, the value of ReadTimeout is used. If both are zero, ReadHeaderTimeout is used. - IdleTimeout TimeDuration `alias:"idletimeout" json:"idletimeout" description:"Http server idle timeout."` // Go 1.8 + IdleTimeout TimeDuration `alias:"idletimeout" json:"idletimeout" xml:"idletimeout" yaml:"idletimeout" description:"Http server idle timeout."` // MaxHeaderBytes controls the maximum number of bytes the server will read parsing the request header's keys and values, including the request line. // It does not limit the size of the request body. If zero, DefaultMaxHeaderBytes is used. - MaxHeaderBytes int `alias:"maxheaderbytes" json:"maxheaderbytes" description:"Http server max header size."` + MaxHeaderBytes int `alias:"maxheaderbytes" json:"maxheaderbytes" xml:"maxheaderbytes" yaml:"maxheaderbytes" description:"Http server max header size."` // ErrorLog specifies an optional logger for errors accepting // connections, unexpected behavior from handlers, and // underlying FileSystem errors. // If nil, logging is done via the log package's standard logger. - ErrorLog *log.Logger // Go 1.3 + ErrorLog *log.Logger `alias:"errorlog" json:"-" xml:"-" yaml:"-"` // BaseContext optionally specifies a function that returns the base context for incoming requests on this server. // The provided Listener is the specific Listener that's about to start accepting requests. // If BaseContext is nil, the default is context.Background(). If non-nil, it must return a non-nil context. - BaseContext func(net.Listener) context.Context `alias:"basecontext" json:"-"` // Go 1.13 + BaseContext func(net.Listener) context.Context `alias:"basecontext" json:"-" xml:"-" yaml:"-"` // ConnContext optionally specifies a function that modifies the context used for a new connection c. // The provided ctx is derived from the base context and has a ServerContextKey value. - ConnContext func(context.Context, net.Conn) context.Context `alias:"conncontext" json:"-"` // Go 1.13 + ConnContext func(context.Context, net.Conn) context.Context `alias:"conncontext" json:"-" xml:"-" yaml:"-"` } // serverStd 定义使用net/http启动http server。 @@ -81,14 +80,14 @@ type serverStd struct { Counter int64 } -type serverStdMetadata struct { - Health bool `json:"health" xml:"health"` - Name string `json:"name" xml:"name"` - Ports []string `json:"ports" xml:"ports"` - ErrorCount int64 `json:"error_count" xml:"error-count"` +type MetadataServer struct { + Health bool `alias:"health" json:"health" xml:"health" yaml:"health"` + Name string `alias:"name" json:"name" xml:"name" yaml:"name"` + Ports []string `alias:"ports" json:"ports" xml:"ports" yaml:"ports"` + ErrorCount int64 `alias:"errorcount" json:"errorcount" xml:"errorcount" yaml:"errorcount"` } -// serverFcgi 定义fastcgi server +// serverFcgi 定义fastcgi server。 type serverFcgi struct { http.Handler sync.Mutex @@ -97,41 +96,52 @@ type serverFcgi struct { // ServerListenConfig 定义一个通用的端口监听配置,监听https仅支持单证书。 type ServerListenConfig struct { - NewListen func(string, string) (net.Listener, error) `alias:"newlisten" json:"-" description:"create listener func, default: net.Listen"` - Addr string `alias:"addr" json:"addr" description:"Listen addr."` - HTTPS bool `alias:"https" json:"https" description:"Is https."` - HTTP2 bool `alias:"http2" json:"http2" description:"Is http2."` - Mutual bool `alias:"mutual" json:"mutual" description:"Is mutual tls."` - Certfile string `alias:"certfile" json:"certfile" description:"Http server cert file."` - Keyfile string `alias:"keyfile" json:"keyfile" description:"Http server key file."` - Trustfile string `alias:"trustfile" json:"trustfile" description:"Http client ca file."` - Certificate *x509.Certificate `alias:"certificate" json:"certificate" description:"https use tls certificate."` + Addr string `alias:"addr" json:"addr" xml:"addr" yaml:"addr" description:"Listen addr."` + HTTPS bool `alias:"https" json:"https" xml:"https" yaml:"https" description:"Is https."` + HTTP2 bool `alias:"http2" json:"http2" xml:"http2" yaml:"http2" description:"Is http2."` + Mutual bool `alias:"mutual" json:"mutual" xml:"mutual" yaml:"mutual" description:"Is mutual tls."` + Certfile string `alias:"certfile" json:"certfile" xml:"certfile" yaml:"certfile" description:"Http server cert file."` + Keyfile string `alias:"keyfile" json:"keyfile" xml:"keyfile" yaml:"keyfile" description:"Http server key file."` + Trustfile string `alias:"trustfile" json:"trustfile" xml:"trustfile" yaml:"trustfile" description:"Http client ca file."` + Certificate *x509.Certificate `alias:"certificate" json:"certificate" xml:"certificate" yaml:"certificate" description:"https use tls certificate."` } -// NewServerStd 创建一个标准server。 -func NewServerStd(arg interface{}) Server { +// NewServer 创建一个标准server。 +func NewServer(config *ServerConfig) Server { + if config == nil { + config = &ServerConfig{} + } srv := &serverStd{ Server: &http.Server{ - ReadTimeout: 60 * time.Second, - ReadHeaderTimeout: 60 * time.Second, - WriteTimeout: 60 * time.Second, - IdleTimeout: 60 * time.Second, - TLSNextProto: nil, + Handler: config.Handler, + ReadTimeout: GetAnyDefault(time.Duration(config.ReadTimeout), DefaultServerReadTimeout), + ReadHeaderTimeout: GetAnyDefault(time.Duration(config.ReadHeaderTimeout), DefaultServerReadHeaderTimeout), + WriteTimeout: GetAnyDefault(time.Duration(config.WriteTimeout), DefaultServerWriteTimeout), + IdleTimeout: GetAnyDefault(time.Duration(config.IdleTimeout), DefaultServerIdleTimeout), + MaxHeaderBytes: config.MaxHeaderBytes, + ErrorLog: config.ErrorLog, + BaseContext: config.BaseContext, + ConnContext: config.ConnContext, }, Logger: DefaultLoggerNull, } // 捕捉net/http.Server输出的error内容。 - srv.Server.ErrorLog = log.New(srv, "", 0) - ConvertTo(arg, srv.Server) + if srv.ErrorLog == nil { + srv.ErrorLog = log.New(srv, "", 0) + } return srv } // Mount 方法获取ContextKeyApp.(Logger)用于输出http.Server错误日志。 // 获取ContextKeyApp.(http.Handler)作为http.Server的处理对象。 func (srv *serverStd) Mount(ctx context.Context) { - srv.SetHandler(ctx.Value(ContextKeyApp).(http.Handler)) - srv.BaseContext = func(net.Listener) context.Context { - return ctx + if srv.Handler == nil { + srv.SetHandler(ctx.Value(ContextKeyApp).(http.Handler)) + } + if srv.BaseContext == nil { + srv.BaseContext = func(net.Listener) context.Context { + return ctx + } } log, ok := ctx.Value(ContextKeyApp).(Logger) if ok { @@ -140,10 +150,10 @@ func (srv *serverStd) Mount(ctx context.Context) { } // Unmount 方法等待DefaulerServerShutdownWait(默认60s)优雅停机。 -func (srv *serverStd) Unmount(ctx context.Context) { +func (srv *serverStd) Unmount(context.Context) { ctx, cancel := context.WithTimeout(context.Background(), DefaultServerShutdownWait) defer cancel() - srv.Shutdown(ctx) + _ = srv.Shutdown(ctx) } // SetHandler 方法设置server的http处理者。 @@ -159,7 +169,7 @@ func (srv *serverStd) Serve(ln net.Listener) error { srv.Ports = append(srv.Ports, ln.Addr().String()) srv.Mutex.Unlock() err := srv.Server.Serve(ln) - if err == http.ErrServerClosed { + if errors.Is(err, http.ErrServerClosed) { err = nil } return err @@ -171,19 +181,21 @@ func (srv *serverStd) ServeConn(conn net.Conn) { if srv.localListener.Ch == nil { srv.localListener.Ch = make(chan net.Conn) srv.Ports = append(srv.Ports, srv.localListener.Addr().String()) - go srv.Server.Serve(&srv.localListener) + go func() { + _ = srv.Server.Serve(&srv.localListener) + }() } srv.Mutex.Unlock() srv.localListener.Ch <- conn } // Metadata 方法返回serverStd元数据。 -func (srv *serverStd) Metadata() interface{} { +func (srv *serverStd) Metadata() any { srv.Mutex.Lock() defer srv.Mutex.Unlock() - return serverStdMetadata{ + return MetadataServer{ Health: true, - Name: "eudore.ServerStd", + Name: "eudore.serverStd", Ports: srv.Ports, ErrorCount: atomic.LoadInt64(&srv.Counter), } @@ -191,15 +203,28 @@ func (srv *serverStd) Metadata() interface{} { func (srv *serverStd) Write(p []byte) (n int, err error) { atomic.AddInt64(&srv.Counter, 1) + log := srv.Logger.WithField(ParamDepth, "disable").WithField(ParamCaller, "*serverStd.ErrorLog.Write") strs := strings.Split(string(p), "\n") if strings.HasPrefix(strs[0], "http: panic serving ") { lines := []string{} for i := 2; i < len(strs)-1; i += 2 { - lines = append(lines, strs[i]+" "+strs[i+1][2:]) + if strings.HasPrefix(strs[i], "created by ") { + strs[i] = strs[i][11:] + } else { + end := strings.LastIndexByte(strs[i], '(') + if end != -1 { + strs[i] = strs[i][:end] + } + } + pos := strings.IndexByte(strs[i+1], ' ') + if pos != -1 { + strs[i+1] = strs[i+1][:pos] + } + lines = append(lines, strings.TrimPrefix(strs[i+1], "\t")+" "+strs[i]) } - srv.Logger.WithField("depth", "disable").WithField("depth", lines).Errorf("%s %s", strs[0], strs[1][:len(strs[1])-1]) + log.WithField("stack", lines).Errorf("%s %s", strs[0], strs[1][:len(strs[1])-1]) } else { - srv.Logger.WithField("depth", "disable").Errorf(strs[0]) + log.Errorf(strs[0]) } return 0, nil } @@ -243,10 +268,10 @@ func (srv *serverFcgi) Mount(ctx context.Context) { } // Unmount 方法等待DefaulerServerShutdownWait(默认60s)优雅停机。 -func (srv *serverFcgi) Unmount(ctx context.Context) { +func (srv *serverFcgi) Unmount(context.Context) { ctx, cancel := context.WithTimeout(context.Background(), DefaultServerShutdownWait) defer cancel() - srv.Shutdown(ctx) + _ = srv.Shutdown(ctx) } // SetHandler 方法设置fcgi处理对象。 @@ -263,10 +288,10 @@ func (srv *serverFcgi) Serve(ln net.Listener) error { } // Shutdown 方法关闭fcgi关闭监听。 -func (srv *serverFcgi) Shutdown(ctx context.Context) error { +func (srv *serverFcgi) Shutdown(context.Context) error { srv.Lock() defer srv.Unlock() - var errs errormulit + var errs mulitError for _, ln := range srv.listeners { errs.HandleError(ln.Close()) } @@ -275,11 +300,8 @@ func (srv *serverFcgi) Shutdown(ctx context.Context) error { // Listen 方法使ServerListenConfig实现serverListener接口,用于使用对象创建监听。 func (slc *ServerListenConfig) Listen() (net.Listener, error) { - if slc.NewListen == nil { - slc.NewListen = net.Listen - } // set default port - if len(slc.Addr) == 0 { + if slc.Addr == "" { if slc.HTTPS { slc.Addr = ":80" } else { @@ -287,7 +309,7 @@ func (slc *ServerListenConfig) Listen() (net.Listener, error) { } } if !slc.HTTPS { - return slc.NewListen("tcp", slc.Addr) + return DefaultServerListen("tcp", slc.Addr) } // set tls @@ -307,7 +329,7 @@ func (slc *ServerListenConfig) Listen() (net.Listener, error) { // set mutual tls if slc.Mutual { - data, err := ioutil.ReadFile(slc.Trustfile) + data, err := os.ReadFile(slc.Trustfile) if err != nil { return nil, err } @@ -317,7 +339,7 @@ func (slc *ServerListenConfig) Listen() (net.Listener, error) { config.ClientAuth = tls.RequireAndVerifyClientCert } - ln, err := slc.NewListen("tcp", slc.Addr) + ln, err := DefaultServerListen("tcp", slc.Addr) if err != nil { return nil, err } diff --git a/util.go b/util.go index 99b0bb0..3bb8160 100644 --- a/util.go +++ b/util.go @@ -2,8 +2,10 @@ package eudore import ( "bytes" + "crypto/rand" "encoding/json" "fmt" + "io" "reflect" "strconv" "strings" @@ -15,7 +17,7 @@ type contextKey struct { } // NewContextKey 定义context key。 -func NewContextKey(key string) interface{} { +func NewContextKey(key string) any { return contextKey{key} } @@ -36,8 +38,8 @@ func NewParamsRoute(path string) Params { params := make(Params, 0, len(args)*2+2) params = append(params, ParamRoute, route) for _, str := range args { - k, v := split2byte(str, '=') - if v != "" { + k, v, ok := strings.Cut(str, "=") + if ok && v != "" { params = append(params, k, v) } } @@ -53,7 +55,7 @@ func (p Params) Clone() Params { // CombineWithRoute 方法将params数据合并到p,用于路由路径合并。 func (p Params) CombineWithRoute(params Params) Params { - p[1] = p[1] + params[1] + p[1] += params[1] for i := 2; i < len(params); i += 2 { p = p.Set(params[i], params[i+1]) } @@ -111,7 +113,7 @@ func (p Params) Set(key, val string) Params { return append(p, key, val) } -// Del 方法删除一个参数值 +// Del 方法删除一个参数值。 func (p Params) Del(key string) { for i := 0; i < len(p); i += 2 { if p[i] == key { @@ -120,391 +122,11 @@ func (p Params) Del(key string) { } } -// TimeDuration 定义time.Duration类型处理json -type TimeDuration time.Duration - -// String 方法格式化输出时间。 -func (d TimeDuration) String() string { - return time.Duration(d).String() -} - -// MarshalJSON 方法实现json序列化输出。 -func (d TimeDuration) MarshalJSON() ([]byte, error) { - return json.Marshal(time.Duration(d).String()) -} - -// UnmarshalJSON 方法实现解析json格式时间。 -func (d *TimeDuration) UnmarshalJSON(b []byte) error { - str := string(b) - if str != "" && str[0] == '"' && str[len(str)-1] == '"' { - str = str[1 : len(str)-1] - } - // parse int64 - val, err := strconv.ParseInt(str, 10, 64) - if err == nil { - *d = TimeDuration(val) - return nil - } - // parse string - t, err := time.ParseDuration(str) - if err == nil { - *d = TimeDuration(t) - return nil - } - return fmt.Errorf("invalid duration type %T, value: '%s'", b, b) -} - -// split2byte internal function, splits two strings into two segments using the first specified byte, and returns "", str if there is no split symbol. -// -// split2byte 内部函数,使用第一个指定byte两字符串分割成两段,如果不存在分割符号,返回"", str。 -func split2byte(str string, b byte) (string, string) { - pos := strings.IndexByte(str, b) - if pos == -1 { - return str, "" - } - return str[:pos], str[pos+1:] -} - -// GetBool 函数转换bool、int、uint、float、string成bool。 -func GetBool(i interface{}) bool { - iValue := reflect.ValueOf(i) - switch iValue.Kind() { - case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - return iValue.Int() != 0 - case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - return iValue.Uint() != 0 - case reflect.Float32, reflect.Float64: - return iValue.Float() != 0 - case reflect.String: - str := iValue.String() - return str != "" && str != "true" && str != "1" - default: - return false - } -} - -// GetInt 函数转换一个bool、int、uint、float、string类型成int,或者返回第一个非零值。 -func GetInt(i interface{}, nums ...int) int { - var number int - iValue := reflect.ValueOf(i) - switch iValue.Kind() { - case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - number = int(iValue.Int()) - case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - number = int(iValue.Uint()) - case reflect.Float32, reflect.Float64: - number = int(iValue.Float()) - case reflect.String: - if v, err := strconv.Atoi(iValue.String()); err == nil { - number = v - } - } - if number != 0 { - return number - } - for _, i := range nums { - if i != 0 { - return i - } - } - return 0 -} - -// GetInt64 函数转换一个bool、int、uint、float、string类型成int64,或者返回第一个非零值。 -func GetInt64(i interface{}, nums ...int64) int64 { - var number int64 - iValue := reflect.ValueOf(i) - switch iValue.Kind() { - case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - number = iValue.Int() - case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - number = int64(iValue.Uint()) - case reflect.Float32, reflect.Float64: - number = int64(iValue.Float()) - case reflect.String: - if v, err := strconv.ParseInt(iValue.String(), 10, 64); err == nil { - number = v - } - } - if number != 0 { - return number - } - for _, i := range nums { - if i != 0 { - return i - } - } - return 0 -} - -// GetUint 函数转换一个bool、int、uint、float、string类型成uint,或者返回第一个非零值。 -func GetUint(i interface{}, nums ...uint) uint { - var number uint - iValue := reflect.ValueOf(i) - switch iValue.Kind() { - case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - number = uint(iValue.Int()) - case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - number = uint(iValue.Uint()) - case reflect.Float32, reflect.Float64: - number = uint(iValue.Float()) - case reflect.String: - if v, err := strconv.ParseUint(iValue.String(), 10, 64); err == nil { - number = uint(v) - } - } - if number != 0 { - return number - } - for _, i := range nums { - if i != 0 { - return i - } - } - return 0 -} - -// GetUint64 函数转换一个bool、int、uint、float、string类型成uint64,或者返回第一个非零值。 -func GetUint64(i interface{}, nums ...uint64) uint64 { - var number uint64 - iValue := reflect.ValueOf(i) - switch iValue.Kind() { - case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - number = uint64(iValue.Int()) - case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - number = iValue.Uint() - case reflect.Float32, reflect.Float64: - number = uint64(iValue.Float()) - case reflect.String: - if v, err := strconv.ParseUint(iValue.String(), 10, 64); err == nil { - number = v - } - } - if number != 0 { - return number - } - for _, i := range nums { - if i != 0 { - return i - } - } - return 0 -} - -// GetFloat32 函数转换一个bool、int、uint、float、string类型成float32,或者返回第一个非零值。 -func GetFloat32(i interface{}, nums ...float32) float32 { - var number float32 - iValue := reflect.ValueOf(i) - switch iValue.Kind() { - case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - number = float32(iValue.Int()) - case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - number = float32(iValue.Uint()) - case reflect.Float32, reflect.Float64: - number = float32(iValue.Float()) - case reflect.String: - if v, err := strconv.ParseFloat(iValue.String(), 32); err == nil { - return float32(v) - } - } - if number != 0 { - return number - } - for _, i := range nums { - if i != 0 { - return i - } - } - return 0 -} - -// GetFloat64 函数转换一个bool、int、uint、float、string类型成float64,或者返回第一个非零值。 -func GetFloat64(i interface{}, nums ...float64) float64 { - var number float64 - iValue := reflect.ValueOf(i) - switch iValue.Kind() { - case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - number = float64(iValue.Int()) - case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - number = float64(iValue.Uint()) - case reflect.Float32, reflect.Float64: - number = iValue.Float() - case reflect.String: - if v, err := strconv.ParseFloat(iValue.String(), 64); err == nil { - return v - } - } - if number != 0 { - return number - } - for _, i := range nums { - if i != 0 { - return i - } - } - return 0 -} - -// GetString 方法转换一个bool、int、uint、float、string成string类型,或者返回第一个非零值,如果参数类型是string必须是非空才会作为返回值。 -func GetString(i interface{}, strs ...string) string { - var str string - iValue := reflect.ValueOf(i) - switch iValue.Kind() { - case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - str = strconv.FormatInt(iValue.Int(), 10) - case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - str = strconv.FormatUint(iValue.Uint(), 10) - case reflect.Float32, reflect.Float64: - str = strconv.FormatFloat(iValue.Float(), 'f', -1, 64) - case reflect.String: - str = iValue.String() - case reflect.Bool: - str = strconv.FormatBool(iValue.Bool()) - default: - switch val := i.(type) { - case fmt.Stringer: - str = val.String() - case []byte: - str = string(val) - } - } - if str != "" { - return str - } - for _, i := range strs { - if i != "" { - return i - } - } - return "" -} - -// GetBytes 方法断言[]byte类型或使用GetString方法转换成string类型 -func GetBytes(i interface{}) []byte { - body, ok := i.([]byte) - if ok { - return body - } - - str := GetString(i) - if str != "" { - return []byte(str) - } - - return nil -} - -// GetStrings 转换string、[]strng、[]interface{}成[]string。 -func GetStrings(i interface{}) []string { - if i == nil { - return nil - } - switch val := i.(type) { - case string: - return []string{val} - case []string: - return val - case []interface{}: - strs := make([]string, len(val)) - for i := range val { - strs[i] = GetString(val[i]) - } - return strs - } - return nil -} - -// GetStringBool 使用strconv.ParseBool解析。 -func GetStringBool(str string) bool { - if v, err := strconv.ParseBool(str); err == nil { - return v - } - return false -} - -// GetStringInt 使用strconv.Atoi解析返回数据,如果解析返回错误使用第一个非零值。 -func GetStringInt(str string, nums ...int) int { - if v, err := strconv.Atoi(str); err == nil && v != 0 { - return v - } - for _, i := range nums { - if i != 0 { - return i - } - } - return 0 -} - -// GetStringInt64 使用strconv.ParseInt解析返回数据,如果解析返回错误使用第一个非零值。 -func GetStringInt64(str string, nums ...int64) int64 { - if v, err := strconv.ParseInt(str, 10, 64); err == nil && v != 0 { - return v - } - for _, i := range nums { - if i != 0 { - return i - } - } - return 0 -} - -// GetStringUint 使用strconv.ParseUint解析返回数据,如果解析返回错误使用第一个非零值。 -func GetStringUint(str string, nums ...uint) uint { - if v, err := strconv.ParseUint(str, 10, 64); err == nil && v != 0 { - return uint(v) - } - for _, i := range nums { - if i != 0 { - return i - } - } - return 0 -} - -// GetStringUint64 使用strconv.ParseUint解析返回数据,如果解析返回错误使用第一个非零值。 -func GetStringUint64(str string, nums ...uint64) uint64 { - if v, err := strconv.ParseUint(str, 10, 64); err == nil && v != 0 { - return v - } - for _, i := range nums { - if i != 0 { - return i - } - } - return 0 -} - -// GetStringFloat32 使用strconv.ParseFloa解析数据,如果解析返回错误使用第一个第一个非零值。 -func GetStringFloat32(str string, nums ...float32) float32 { - if v, err := strconv.ParseFloat(str, 32); err == nil && v != 0 { - return float32(v) - } - for _, i := range nums { - if i != 0 { - return i - } - } - return 0 -} - -// GetStringFloat64 使用strconv.ParseFloa解析数据,如果解析返回错误使用第一个第一个非零值。 -func GetStringFloat64(str string, nums ...float64) float64 { - if v, err := strconv.ParseFloat(str, 64); err == nil && v != 0 { - return v - } - for _, i := range nums { - if i != 0 { - return i - } - } - return 0 -} - // GetWarp 对象封装Get函数提供类型转换功能。 -type GetWarp func(string) interface{} +type GetWarp func(string) any // NewGetWarp 函数创建一个getwarp处理类型转换。 -func NewGetWarp(fn func(string) interface{}) GetWarp { +func NewGetWarp(fn func(string) any) GetWarp { return fn } @@ -515,87 +137,116 @@ func NewGetWarpWithConfig(c Config) GetWarp { // NewGetWarpWithApp 函数使用App创建getwarp。 func NewGetWarpWithApp(app *App) GetWarp { - return func(key string) interface{} { + return func(key string) any { return app.Get(key) } } -// NewGetWarpWithMapString 函数使用map[string]interface{}创建getwarp。 -func NewGetWarpWithMapString(data map[string]interface{}) GetWarp { - return func(key string) interface{} { +// NewGetWarpWithMapString 函数使用map[string]any创建getwarp。 +func NewGetWarpWithMapString(data map[string]any) GetWarp { + return func(key string) any { return data[key] } } // NewGetWarpWithObject 函数使用map或创建getwarp。 -func NewGetWarpWithObject(obj interface{}) GetWarp { - return func(key string) interface{} { - return Get(obj, key) +func NewGetWarpWithObject(obj any) GetWarp { + return func(key string) any { + return GetAnyByPath(obj, key) } } -// GetInterface 方法获取interface类型的配置值。 -func (fn GetWarp) GetInterface(key string) interface{} { +// GetAny 方法获取any类型的配置值。 +func (fn GetWarp) GetAny(key string) any { return fn(key) } // GetBool 方法获取bool类型的配置值。 func (fn GetWarp) GetBool(key string) bool { - return GetBool(fn(key)) + return GetAny[bool](fn(key)) } // GetInt 方法获取int类型的配置值。 func (fn GetWarp) GetInt(key string, vals ...int) int { - return GetInt(fn(key), vals...) + return GetAny(fn(key), vals...) } // GetUint 方法取获取uint类型的配置值。 func (fn GetWarp) GetUint(key string, vals ...uint) uint { - return GetUint(fn(key), vals...) + return GetAny(fn(key), vals...) } // GetInt64 方法int64类型的配置值。 func (fn GetWarp) GetInt64(key string, vals ...int64) int64 { - return GetInt64(fn(key), vals...) + return GetAny(fn(key), vals...) } // GetUint64 方法取获取uint64类型的配置值。 func (fn GetWarp) GetUint64(key string, vals ...uint64) uint64 { - return GetUint64(fn(key), vals...) + return GetAny(fn(key), vals...) } // GetFloat32 方法取获取float32类型的配置值。 func (fn GetWarp) GetFloat32(key string, vals ...float32) float32 { - return GetFloat32(fn(key), vals...) + return GetAny(fn(key), vals...) } // GetFloat64 方法取获取float64类型的配置值。 func (fn GetWarp) GetFloat64(key string, vals ...float64) float64 { - return GetFloat64(fn(key), vals...) + return GetAny(fn(key), vals...) } -// GetString 方法获取一个字符串,如果字符串为空返回其他默认非空字符串, +// GetString 方法获取一个字符串,如果字符串为空返回其他默认非空字符串。 func (fn GetWarp) GetString(key string, vals ...string) string { - return GetString(fn(key), vals...) + return GetStringByAny(fn(key), vals...) } -// GetBytes 方法获取[]byte类型的配置值,如果是字符串类型会转换成[]byte。 -func (fn GetWarp) GetBytes(key string) []byte { - return GetBytes(fn(key)) +// TimeDuration 定义time.Duration类型处理json。 +type TimeDuration time.Duration + +// String 方法格式化输出时间。 +func (d TimeDuration) String() string { + return time.Duration(d).String() } -// GetStrings 方法获取[]string值 -func (fn GetWarp) GetStrings(key string) []string { - return GetStrings(fn(key)) +// MarshalText 方法实现json序列化输出。 +func (d TimeDuration) MarshalText() ([]byte, error) { + return []byte(time.Duration(d).String()), nil } -// errormulit 实现多个error组合。 -type errormulit struct { +// UnmarshalJSON 方法实现解析json格式时间。 +func (d *TimeDuration) UnmarshalJSON(b []byte) error { + if len(b) > 0 && b[0] == '"' && b[len(b)-1] == '"' { + b = b[1 : len(b)-1] + } + return d.UnmarshalText(b) +} + +// UnmarshalText 方法实现解析时间。 +func (d *TimeDuration) UnmarshalText(b []byte) error { + str := string(b) + // parse int64 + val, err := strconv.ParseInt(str, 10, 64) + if err == nil { + *d = TimeDuration(val) + return nil + } + // parse string + t, err := time.ParseDuration(str) + if err == nil { + *d = TimeDuration(t) + return nil + } + return fmt.Errorf("invalid duration value: '%s'", b) +} + +// mulitError 实现多个error组合。 +type mulitError struct { errs []error } // HandleError 实现处理多个错误,如果非空则保存错误。 -func (err *errormulit) HandleError(errs ...error) { +func (err *mulitError) HandleError(errs ...error) { for _, e := range errs { if e != nil { err.errs = append(err.errs, e) @@ -604,12 +255,12 @@ func (err *errormulit) HandleError(errs ...error) { } // Error 方法实现error接口,返回错误描述。 -func (err *errormulit) Error() string { +func (err *mulitError) Error() string { return fmt.Sprint(err.errs) } // GetError 方法返回错误,如果没有保存的错误则返回空。 -func (err *errormulit) Unwrap() error { +func (err *mulitError) Unwrap() error { switch len(err.errs) { case 0: return nil @@ -620,62 +271,338 @@ func (err *errormulit) Unwrap() error { } } -// NewErrorStatusCode 方法组合ErrorStatus和ErrorCode。 -func NewErrorStatusCode(err error, status, code int) error { +// NewErrorWithStatusCode 方法组合ErrorStatus和ErrorCode。 +func NewErrorWithStatusCode(err error, status, code int) error { if code > 0 { - err = errorCode{err, code} + err = codeError{err, code} } if status > 0 { - err = errorStatus{err, status} + err = statusError{err, status} } return err } -// NewErrorStatus 方法封装error实现Status方法。 -func NewErrorStatus(err error, status int) error { +// NewErrorWithStatus 方法封装error实现Status方法。 +func NewErrorWithStatus(err error, status int) error { if status > 0 { - return errorStatus{err, status} + return statusError{err, status} } return err } -type errorStatus struct { +type statusError struct { err error status int } -func (err errorStatus) Error() string { +func (err statusError) Error() string { return err.err.Error() } -func (err errorStatus) Unwrap() error { +func (err statusError) Unwrap() error { return err.err } -func (err errorStatus) Status() int { +func (err statusError) Status() int { return err.status } -// NewErrorCode 方法封装error实现Code方法。 -func NewErrorCode(err error, code int) error { +// NewErrorWithCode 方法封装error实现Code方法。 +func NewErrorWithCode(err error, code int) error { if code > 0 { - return errorCode{err, code} + return codeError{err, code} } return err } -type errorCode struct { +type codeError struct { err error code int } -func (err errorCode) Error() string { +func (err codeError) Error() string { return err.err.Error() } -func (err errorCode) Unwrap() error { + +func (err codeError) Unwrap() error { return err.err } -func (err errorCode) Code() int { +func (err codeError) Code() int { return err.code } + +func clearCap[T any](s []T) []T { + l := len(s) + return s[:l:l] +} + +func cutOmit(s string) (string, bool) { + if strings.HasSuffix(s, ",omitempty") { + return s[:len(s)-10], true + } + return s, false +} + +func sliceIndex[T comparable](vals []T, val T) int { + for i := range vals { + if val == vals[i] { + return i + } + } + return -1 +} + +func sliceLastIndex[T comparable](vals []T, val T) int { + for i := len(vals) - 1; i > -1; i-- { + if val == vals[i] { + return i + } + } + return -1 +} + +func sliceFilter[T any](s []T, fn func(T) bool) []T { + size := 0 + b := make([]bool, len(s)) + for i := range s { + b[i] = fn(s[i]) + if b[i] { + size++ + } + } + if size == len(s) { + return s + } + + n := make([]T, 0, size) + for i := range b { + if b[i] { + n = append(n, s[i]) + } + } + return n +} + +// GetAnyDefault 函数返回非空值。 +func GetAnyDefault[T comparable](arg1, arg2 T) T { + var zero T + if arg1 != zero { + return arg1 + } + return arg2 +} + +// GetAnyDefaults 函数返回第一个非空值。 +func GetAnyDefaults[T comparable](args ...T) T { + var zero T + for i := range args { + if args[i] != zero { + return args[i] + } + } + return zero +} + +func SetAnyDefault[T any](arg1, arg2 *T) { + v1 := reflect.Indirect(reflect.ValueOf(arg1)) + v2 := reflect.Indirect(reflect.ValueOf(arg2)) + if v1.Kind() == reflect.Struct && v1.Type() == v2.Type() { + for i := 0; i < v1.NumField(); i++ { + f1, f2 := v1.Field(i), v2.Field(i) + if f1.CanSet() && !f2.IsZero() { + f1.Set(f2) + } + } + } +} + +// TypeNumber 定义泛型数值类型集合。 +type TypeNumber interface { + int | int8 | int16 | int32 | int64 | uint | uint8 | uint16 | uint32 | uint64 | float32 | float64 | complex64 | complex128 +} + +// GetAny 函数类型Value转换成另外一个类型。 +func GetAny[T string | bool | TypeNumber](s any, defaults ...T) T { + var t, zero T + if s != nil { + sValue := reflect.ValueOf(s) + tType := reflect.TypeOf(t) + switch { + case sValue.Type() == tType: + t = sValue.Interface().(T) + case sValue.Kind() == tType.Kind(): + t = sValue.Convert(tType).Interface().(T) + case sValue.Kind() == reflect.String: + t = GetAnyByString(sValue.String(), defaults...) + case tType.Kind() == reflect.String: + t = any(GetStringByAny(s)).(T) + case sValue.CanConvert(tType): + t = sValue.Convert(tType).Interface().(T) + } + if t != zero { + return t + } + } + + for _, value := range defaults { + if value != zero { + return value + } + } + return t +} + +// GetStringByAny 函数将any转换成string +// +//nolint:cyclop,gocyclo +func GetStringByAny(i any, strs ...string) string { + var str string + switch v := i.(type) { + case string: + str = v + case int: + str = strconv.FormatInt(int64(v), 10) + case uint: + str = strconv.FormatUint(uint64(v), 10) + case float64: + str = strconv.FormatFloat(v, 'f', -1, 64) + case bool: + str = strconv.FormatBool(v) + case []byte: + str = string(v) + case fmt.Stringer: + str = v.String() + case int64: + str = strconv.FormatInt(v, 10) + case int32: + str = strconv.FormatInt(int64(v), 10) + case int16: + str = strconv.FormatInt(int64(v), 10) + case int8: + str = strconv.FormatInt(int64(v), 10) + case uint64: + str = strconv.FormatUint(v, 10) + case uint32: + str = strconv.FormatUint(uint64(v), 10) + case uint16: + str = strconv.FormatUint(uint64(v), 10) + case uint8: + str = strconv.FormatUint(uint64(v), 10) + case float32: + str = strconv.FormatFloat(float64(v), 'f', -1, 32) + case complex64: + str = strconv.FormatComplex(complex128(v), 'f', -1, 64) + case complex128: + str = strconv.FormatComplex(v, 'f', -1, 128) + default: + str = fmt.Sprint(i) + } + + if str != "" { + return str + } + for _, i := range strs { + if i != "" { + return i + } + } + return "" +} + +// GetStringRandom 函数返回指定长度随机字符串。 +func GetStringRandom(length int) string { + buf := make([]byte, length) + io.ReadFull(rand.Reader, buf) + return fmt.Sprintf("%x", buf) +} + +// GetAnyByString 函数将字符串转换为其他值。 +func GetAnyByString[T string | bool | TypeNumber | time.Time | time.Duration](str string, defaults ...T) T { + val, _ := GetAnyByStringWithError(str, defaults...) + return val +} + +// GetAnyByStringWithError 函数将字符串转换成泛型数值。 +// +//nolint:cyclop,funlen,gocyclo +func GetAnyByStringWithError[T string | bool | TypeNumber | time.Time | time.Duration](str string, defaults ...T) (T, error) { + var zero T + var val any + var err error + switch any(zero).(type) { + case int: + val, err = strconv.Atoi(str) + case float64: + val, err = strconv.ParseFloat(str, 64) + case string: + val = str + case bool: + val, err = strconv.ParseBool(str) + case int8: + var v int64 + v, err = strconv.ParseInt(str, 10, 8) + val = int8(v) + case int16: + var v int64 + v, err = strconv.ParseInt(str, 10, 16) + val = int16(v) + case int32: + var v int64 + v, err = strconv.ParseInt(str, 10, 16) + val = int32(v) + case int64: + val, err = strconv.ParseInt(str, 10, 64) + case uint: + var v uint64 + v, err = strconv.ParseUint(str, 10, 32) + val = uint(v) + case uint8: + var v uint64 + v, err = strconv.ParseUint(str, 10, 8) + val = uint8(v) + case uint16: + var v uint64 + v, err = strconv.ParseUint(str, 10, 16) + val = uint16(v) + case uint32: + var v uint64 + v, err = strconv.ParseUint(str, 10, 32) + val = uint32(v) + case uint64: + val, err = strconv.ParseUint(str, 10, 64) + case float32: + var v float64 + v, err = strconv.ParseFloat(str, 32) + val = float32(v) + case complex64: + var v complex128 + v, err = strconv.ParseComplex(str, 64) + val = complex64(v) + case complex128: + val, err = strconv.ParseComplex(str, 128) + case time.Duration: + val, err = time.ParseDuration(str) + case time.Time: + var v time.Time + for i, f := range DefaultValueParseTimeFormats { + if DefaultValueParseTimeFixed[i] && len(str) != len(f) { + continue + } + v, err = time.Parse(f, str) + if err == nil { + break + } + } + val = v + } + if val != zero { + return val.(T), err + } + for _, value := range defaults { + if value != zero { + return value, err + } + } + return zero, err +} diff --git a/util_16.go b/util_16.go deleted file mode 100644 index 2bd7fa7..0000000 --- a/util_16.go +++ /dev/null @@ -1,66 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package eudore - -import ( - "embed" - "net/http" - "strings" -) - -func init() { - formartErr := "error: %v" - formartErr13 := "error: %w" - ErrFormatRouterStdAddController = strings.Replace(ErrFormatRouterStdAddController, formartErr, formartErr13, 1) - ErrFormatRouterStdAddHandlerExtend = strings.Replace(ErrFormatRouterStdAddHandlerExtend, formartErr, formartErr13, 1) - ErrFormatRouterStdRegisterHandlersRecover = strings.Replace(ErrFormatRouterStdRegisterHandlersRecover, formartErr, formartErr13, 1) - - DefaultHandlerExtend.RegisterHandlerExtend("", NewExtendFuncEmbed) -} - -func NewExtendFuncEmbed(path string, f embed.FS) HandlerFunc { - return NewHandlerEmbedFunc(f, strings.Split(getRouteParam(path, "dir"), ";")...) -} - -// NewHandlerEmbedFunc 函数使用embed.FS和指定目录文件处理响应,依次寻找dirs多个目录是否存在文件,否则使用embed.FS作为默认FS返回响应。 -func NewHandlerEmbedFunc(f embed.FS, dirs ...string) HandlerFunc { - var h fileSystems - for i := range dirs { - if dirs[i] != "" { - h = append(h, http.Dir(dirs[i])) - } - } - h = append(h, http.FS(f)) - return func(ctx Context) { - file, err := h.Open(ctx.GetParam("*")) - if err != nil { - ctx.Fatal(err) - return - } - stat, _ := file.Stat() - // embed.FS的ModTime()为空无法使用缓存,设置为启动时间使用304缓存机制。 - modtime := stat.ModTime() - if modtime.IsZero() { - modtime = DefaultEmbedTime - } - if ctx.Request().Header.Get(HeaderCacheControl) == "" { - ctx.SetHeader(HeaderCacheControl, DefaultEmbedCacheControl) - } - http.ServeContent(ctx.Response(), ctx.Request(), stat.Name(), modtime, file) - } -} - -// 组合多个http.FileSystem -type fileSystems []http.FileSystem - -func (fs fileSystems) Open(name string) (file http.File, err error) { - for _, f := range fs { - // 依次打开多个http.FileSystem返回一个成功打开的数据。 - file, err = f.Open(name) - if err == nil { - return - } - } - return -} diff --git a/value.go b/value.go new file mode 100644 index 0000000..75cd507 --- /dev/null +++ b/value.go @@ -0,0 +1,701 @@ +package eudore + +import ( + "encoding" + "fmt" + "reflect" + "strconv" + "strings" + "time" + "unsafe" +) + +type value struct { + Tags []string + Keys []string + Index int + All bool + Set bool + Value any + Pointers []uintptr + Pindex int +} + +// GetAnyByPath method A more path to get an attribute from an object. +// +// The path will be split using '.' and then look for the path in turn. +// +// Structure attributes can use the structure tag 'alias' to match attributes. +// +// Returns a null value if the match fails. +// +// 根据路径来从一个对象获得一个属性。 +// +// 路径将使用'.'分割,然后依次寻找路径。 +// +// 结构体属性可以使用结构体标签'alias'来匹配属性。 +// +// 如果匹配失败直接返回空值。 +func GetAnyByPath(i any, key string) any { + val, err := getValue(i, key, nil, false) + if err != nil { + return nil + } + return val.Interface() +} + +// GetAnyByPathWithTag 函数和GetAnyByPath函数相同,可以额外设置tags,同时会返回error。 +func GetAnyByPathWithTag(i any, key string, tags []string, all bool) (any, error) { + val, err := getValue(i, key, tags, all) + if err != nil { + return nil, err + } + if all { + val = reflect.NewAt(val.Type(), unsafe.Pointer(val.UnsafeAddr())).Elem() + } + return val.Interface(), nil +} + +// GetAnyByPathWithValue 函数和Get函数相同,可以允许查找私有属性并返回reflect.Value。 +func GetAnyByPathWithValue(i any, key string, tags []string, all bool) (reflect.Value, error) { + return getValue(i, key, tags, all) +} + +func getValue(i any, key string, tags []string, all bool) (reflect.Value, error) { + val, ok := i.(reflect.Value) + if !ok { + val = reflect.ValueOf(i) + } + if i == nil { + return val, ErrValueInputDataNil + } + if key == "" { + return val, nil + } + if tags == nil { + tags = DefaultValueGetSetTags + } + v := &value{ + Tags: tags, + Keys: strings.Split(key, "."), + All: all, + } + v.Pointers = make([]uintptr, 0, len(v.Keys)) + return v.getValue(val) +} + +// 从目标类型获取字符串路径的属性。 +func (v *value) getValue(iValue reflect.Value) (reflect.Value, error) { + if len(v.Keys) == v.Index { + return iValue, nil + } + if v.HasPointer(iValue) { + return iValue, v.newError(ErrFormatValueAnonymousField, iValue) + } + switch iValue.Kind() { + case reflect.Ptr, reflect.Interface: + if iValue.IsNil() { + return iValue, v.newError(ErrFormatValueTypeNil, iValue) + } + return v.getValue(iValue.Elem()) + case reflect.Struct: + return v.getStruct(iValue) + case reflect.Map: + return v.getMap(iValue) + case reflect.Array, reflect.Slice: + return v.getSlice(iValue) + } + return iValue, v.newError(ErrFormatValueNotField, iValue, v.Keys[v.Index]) +} + +// 处理结构体对象的读取。 +func (v *value) getStruct(iValue reflect.Value) (reflect.Value, error) { + field := getStructFieldOfTags(iValue, v.Keys[v.Index], v.Tags) + if field.Kind() == reflect.Invalid { + iType := iValue.Type() + for i := 0; i < iType.NumField(); i++ { + if iType.Field(i).Anonymous { + v2, err := v.getValue(iValue.Field(i)) + if err == nil { + return v2, nil + } + } + } + + return iValue, v.newError(ErrFormatValueNotField, iValue, v.Keys[v.Index]) + } + + if field.CanInterface() || v.All { + v.Index++ + defer func() { v.Index-- }() + return v.getValue(field) + } + return iValue, v.newError(ErrFormatValueStructUnexported, iValue, v.Keys[v.Index]) +} + +// 处理map读取属性。 +func (v *value) getMap(iValue reflect.Value) (reflect.Value, error) { + // 检测map是否为空 + if iValue.IsNil() { + return iValue, v.newError(ErrFormatValueTypeNil, iValue) + } + // 创建map需要的key + mapKey := reflect.New(iValue.Type().Key()).Elem() + err := setValueString(mapKey, v.Keys[v.Index]) + if err != nil { + return iValue, v.newError(ErrFormatValueMapIndexInvalid, iValue, v.Keys[v.Index]) + } + + // 获得map的value, 如果值无效则返回空。 + mapvalue := iValue.MapIndex(mapKey) + if mapvalue.Kind() == reflect.Invalid { + return iValue, v.newError(ErrFormatValueMapValueInvalid, iValue, v.Keys[v.Index]) + } + v.Index++ + defer func() { v.Index-- }() + return v.getValue(mapvalue) +} + +// 处理数组切片读取属性。 +func (v *value) getSlice(iValue reflect.Value) (reflect.Value, error) { + // 检测切片是否为空 + if iValue.Kind() == reflect.Slice && iValue.IsNil() { + return iValue, v.newError(ErrFormatValueTypeNil, iValue) + } + // 检测索引是否存在 + index, err := strconv.Atoi(v.Keys[v.Index]) + if err != nil || iValue.Len() <= index || iValue.Len() < -index { + return iValue, v.newError(ErrFormatValueArrayIndexInvalid, iValue, v.Keys[v.Index], iValue.Len()) + } else if index < 0 { + index += iValue.Len() + } + v.Index++ + defer func() { v.Index-- }() + return v.getValue(iValue.Index(index)) +} + +// The SetAnyByPath function sets the properties of an object, and the object must be a pointer type. +// +// The path will be separated using '.', and then the path will be searched for in sequence. +// +// When the object type selected in the path is ptr, it will be checked to see if it is empty. +// If the object is empty, it will be initialized by default. +// +// When the object type selected in the path is any, +// if the object is empty, it will be initialized to map[string]any, +// otherwise the next operation will be determined based on the value type. +// +// When the object type selected in the path is array, +// the path will be converted into an object index to set the array elements, +// and if the index is [], the elements will be appended. +// +// When the object type selected in the path is struct, +// the attribute name and attribute label 'alias' will be used to match when selecting attributes. +// +// If the value type is a string, it will be converted according to the set target type. +// +// If the target type is a string, the value will be output as a string and then assigned. +// +// SetAnyByPath 函数设置一个对象的属性,改对象必须是指针类型。 +// +// 路径将使用'.'分割,然后依次寻找路径。 +// +// 当路径中选择对象类型为ptr时,会检查是否为空,对象为空会默认进行初始化。 +// +// 当路径中选择对象类型为any时,如果对象为空会初始化为map[string]any, +// 否则按值类型来判断下一步操作。 +// +// 当路径中选择对象类型为array时,路径会转换成对象索引来设置数组元素,索引为[]则追加元素。 +// +// 当路径中选择对象类型为struct时,选择属性时会使用属性名称和属性标签'alias'来匹配。 +// +// 如果值的类型是字符串,会根据设置的目标类型来转换。 +// +// 如果目标类型是字符串,将会值输出成字符串然后赋值。 +func SetAnyByPath(i any, key string, val any) error { + return SetAnyByPathWithTag(i, key, val, nil, false) +} + +// SetAnyByPathWithTag 函数和SetAnyByPath函数相同,可以额外设置tags。 +func SetAnyByPathWithTag(i any, key string, val any, tags []string, all bool) error { + if i == nil || key == "" { + return ErrValueInputDataNil + } + iValue, ok := i.(reflect.Value) + if !ok { + iValue = reflect.ValueOf(i) + } + // 检测目标是指针类型。 + if iValue.Kind() != reflect.Ptr { + return ErrValueInputDataNotPtr + } + if tags == nil { + tags = DefaultValueGetSetTags + } + v := &value{ + Tags: tags, + Keys: strings.Split(key, "."), + All: all, + Set: true, + Value: val, + } + v.Pointers = make([]uintptr, 0, len(v.Keys)) + return v.setValue(iValue) +} + +func (v *value) setValue(iValue reflect.Value) error { + if len(v.Keys) == v.Index { + err := setValuePtr(reflect.ValueOf(v.Value), iValue) + if err != nil { + v.Index-- + err = v.newError("%s", iValue, err) + v.Index++ + } + return err + } + if v.HasPointer(iValue) { + return v.newError(ErrFormatValueAnonymousField, iValue) + } + switch iValue.Kind() { + case reflect.Ptr: + if iValue.IsNil() { + return v.setMake(iValue, reflect.New(iValue.Type().Elem())) + } + return v.setValue(iValue.Elem()) + case reflect.Interface: + return v.setInterface(iValue) + case reflect.Struct: + return v.setStruct(iValue) + case reflect.Map: + return v.setMap(iValue) + case reflect.Slice: + return v.setSlice(iValue) + case reflect.Array: + return v.setArray(iValue) + } + + return v.newError(ErrFormatValueNotField, iValue, v.Keys[v.Index]) +} + +func (v *value) setMake(iValue, newValue reflect.Value) error { + err := v.setValue(newValue) + if err == nil { + iValue.Set(newValue) + } + return err +} + +// 处理接口类型。 +func (v *value) setInterface(iValue reflect.Value) (err error) { + // 如果是空接口,初始化为map[string]any类型 + if iValue.IsNil() { + if iValue.Type() != typeAny { + return v.newError(ErrFormatValueTypeNil, iValue) + } + return v.setMake(iValue, reflect.ValueOf(make(map[string]any))) + } + // 创建一个可取地址的临时变量,并设置值用于下一步设置。 + newValue := reflect.New(iValue.Elem().Type()).Elem() + newValue.Set(iValue.Elem()) + err = v.setValue(newValue) + // 将修改后的值重新赋值给对象 + if err == nil { + iValue.Set(newValue) + } + return err +} + +// 处理结构体设置属性。 +func (v *value) setStruct(iValue reflect.Value) error { + field := getStructFieldOfTags(iValue, v.Keys[v.Index], v.Tags) + if field.Kind() == reflect.Invalid { + iType := iValue.Type() + for i := 0; i < iType.NumField(); i++ { + if iType.Field(i).Anonymous { + err := v.setValue(iValue.Field(i)) + if err == nil { + return nil + } + } + } + + return v.newError(ErrFormatValueNotField, iValue, v.Keys[v.Index]) + } + + if !field.CanSet() { + if v.All { + field = reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem() + } else { + return v.newError(ErrFormatValueStructNotCanset, iValue, v.Keys[v.Index]) + } + } + v.Index++ + defer func() { v.Index-- }() + return v.setValue(field) +} + +// 处理map。 +func (v *value) setMap(iValue reflect.Value) error { + iType := iValue.Type() + // 对空map初始化 + if iValue.IsNil() { + return v.setMake(iValue, reflect.MakeMap(iType)) + } + + // 创建map需要匹配的key + mapKey := reflect.New(iType.Key()).Elem() + err := setValueString(mapKey, v.Keys[v.Index]) + if err != nil { + return v.newError(ErrFormatValueMapIndexInvalid, iValue, v.Keys[v.Index]) + } + + newValue := reflect.New(iType.Elem()).Elem() + mapvalue := iValue.MapIndex(mapKey) + if mapvalue.Kind() != reflect.Invalid { + newValue.Set(mapvalue) + } + + v.Index++ + defer func() { v.Index-- }() + err = v.setValue(newValue) + // 将修改后的mapvalue重新赋值给map + if err == nil { + iValue.SetMapIndex(mapKey, newValue) + } + return err +} + +func (v *value) setArray(iValue reflect.Value) error { + index, err := strconv.Atoi(v.Keys[v.Index]) + if err != nil || iValue.Len() <= index || iValue.Len() < -index { + return v.newError(ErrFormatValueArrayIndexInvalid, iValue, v.Keys[v.Index], iValue.Len()) + } else if index < 0 { + index += iValue.Len() + } + v.Index++ + defer func() { v.Index-- }() + return v.setValue(iValue.Index(index)) +} + +// 处理数组和切片。 +func (v *value) setSlice(iValue reflect.Value) error { + iType := iValue.Type() + // 处理空切片 + if iValue.IsNil() { + iValue.Set(reflect.MakeSlice(iType, 0, 4)) + err := v.setSlice(iValue) + if err != nil { + iValue.Set(reflect.Zero(iType)) + } + return err + } + + // 解析index + index, err := strconv.Atoi(v.Keys[v.Index]) + switch { + case (err != nil && v.Keys[v.Index] != "[]") || iValue.Len() < -index: + return v.newError(ErrFormatValueArrayIndexInvalid, iValue, v.Keys[v.Index], iValue.Len()) + case index < 0: + index += iValue.Len() + case v.Keys[v.Index] == "[]": + index = -1 + } + + // 创建新元素的类型和值 + newValue := reflect.New(iType.Elem()).Elem() + if index > -1 { + // 新建数组替换原数组扩容 + if iValue.Cap() <= index { + iValue.Set(reflect.AppendSlice(reflect.MakeSlice(iType, 0, index+1), iValue)) + } + // 对数组长度扩充,新元素添加空值 + if iValue.Len() <= index { + iValue.SetLen(index + 1) + } + // 将原数组值设置给newValue + newValue.Set(iValue.Index(index)) + } + + v.Index++ + defer func() { v.Index-- }() + err = v.setValue(newValue) + if err == nil { + if index > -1 { + iValue.Index(index).Set(newValue) + } else { + iValue.Set(reflect.Append(iValue, newValue)) + } + } + return err +} + +func (v *value) HasPointer(iValue reflect.Value) bool { + kind := iValue.Kind() + if kind < reflect.Map || kind > reflect.Slice { + return false + } + + ptr := iValue.Pointer() + if v.Pointers != nil && v.Index != v.Pindex { + v.Pindex = v.Index + v.Pointers = nil + } + + for _, p := range v.Pointers { + if p == ptr { + return true + } + } + v.Pointers = append(v.Pointers, ptr) + return false +} + +func (v *value) newError(f string, iValue reflect.Value, args ...any) error { + m := "get" + if v.Set { + m = "set" + } + + err := fmt.Errorf(fmt.Sprintf("%s type %s ", iValue.Kind(), iValue.Type())+f, args...) + return fmt.Errorf(ErrFormatValueError, m, strings.Join(v.Keys[:v.Index+1], "."), err) +} + +// 通过字符串获取结构体属性的索引。 +func getStructFieldOfTags(iValue reflect.Value, name string, tags []string) reflect.Value { + iType := iValue.Type() + for i := 0; i < iType.NumField(); i++ { + typeField := iType.Field(i) + // 字符串为结构体名称或结构体属性标签的值,则匹配返回索引。 + if typeField.Name == name { + return iValue.Field(i) + } + for _, tag := range tags { + if typeField.Tag.Get(tag) == name { + return iValue.Field(i) + } + } + } + return reflect.Value{} +} + +// getIndirectAllValue 函数获得解除引用的全部类型和值。 +func getIndirectAllValue(iValue reflect.Value) (types []reflect.Type, values []reflect.Value) { + for { + types = append(types, iValue.Type()) + values = append(values, iValue) + switch iValue.Kind() { + case reflect.Ptr, reflect.Interface: + if iValue.IsNil() { + return + } + iValue = iValue.Elem() + default: + return + } + } +} + +func setValuePtr(sValue reflect.Value, tValue reflect.Value) error { + if sValue.Kind() == reflect.Ptr || sValue.Kind() == reflect.Interface || + tValue.Kind() == reflect.Ptr || tValue.Kind() == reflect.Interface { + stypes, svalues := getIndirectAllValue(sValue) + ttypes, tvalues := getIndirectAllValue(tValue) + for i, ttype := range ttypes { + for j, stype := range stypes { + // 转换接口类型、相同类型、type别名类型 + if stype.ConvertibleTo(ttype) && tvalues[i].CanSet() { + return setValueData(svalues[j], tvalues[i]) + } + } + } + sValue = svalues[len(svalues)-1] + tValue = tvalues[len(tvalues)-1] + + // 目标类型如果是空指针,则尝试进行初始化并转换 + if tValue.Kind() == reflect.Ptr && tValue.IsNil() { + newValue := reflect.New(tValue.Type().Elem()) + err := setValuePtr(sValue, newValue) + if err == nil { + tValue.Set(newValue) + } + return err + } + } + return setValueData(sValue, tValue) +} + +func setValueData(sValue reflect.Value, tValue reflect.Value) error { + sType := sValue.Type() + tType := tValue.Type() + switch { + case sType == tType: + tValue.Set(sValue) + return nil + case sType.ConvertibleTo(tType): + tValue.Set(sValue.Convert(tType)) + return nil + case tValue.Kind() == reflect.Slice: + newValue := reflect.New(tValue.Type().Elem()).Elem() + err := setValueData(sValue, newValue) + if err == nil { + tValue.Set(reflect.Append(tValue, newValue)) + } + return err + case sType.Kind() == reflect.String: + return setValueString(tValue, strings.TrimSpace(sValue.String())) + case tType.Kind() == reflect.String: + tValue.SetString(fmt.Sprintf("%+v", sValue.Interface())) + return nil + } + return fmt.Errorf(ErrFormatValueSetWithValue, sValue.Type().String(), tValue.Type().String()) +} + +var bitSizes = [...]int{0, 0, 0, 8, 16, 32, 64, 0, 8, 16, 32, 64, 32, 64, 32, 64} + +// 使用字符串设置对象的值。 +// +//nolint:cyclop,gocyclo +func setValueString(v reflect.Value, s string) error { + var err error + switch v.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + err = setIntField(v, s) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + err = setUintField(v, s) + case reflect.Bool: + err = setBoolField(v, s) + case reflect.Float32, reflect.Float64: + err = setFloatField(v, s) + case reflect.Complex64, reflect.Complex128: + err = setComplexField(v, s) + case reflect.String: + v.SetString(s) + return nil + case reflect.Ptr: + if v.IsNil() { + newValue := reflect.New(v.Type().Elem()) + err := setValueString(newValue, s) + if err == nil { + v.Set(newValue) + } + return err + } + return setValueString(v.Elem(), s) + case reflect.Interface: + if v.IsNil() && v.Type() == typeAny { + v.Set(reflect.ValueOf(s)) + return nil + } + return setValueString(v.Elem(), s) + case reflect.Struct: + if v.Type().ConvertibleTo(typeTimeTime) { + return setTimeField(v, s) + } + return fmt.Errorf(ErrFormatValueSetStringUnknownType, v.Kind().String()) + default: + return fmt.Errorf(ErrFormatValueSetStringUnknownType, v.Kind().String()) + } + + if err != nil { + p := reflect.New(v.Type()) + e, ok := p.Interface().(encoding.TextUnmarshaler) + if ok { + err = e.UnmarshalText([]byte(s)) + } + if err == nil { + v.Set(p.Elem()) + } + } + return err +} + +func setIntField(field reflect.Value, str string) error { + if str == "" { + str = "0" + } + intVal, err := strconv.ParseInt(str, 10, bitSizes[int(field.Kind())]) + if err == nil { + field.SetInt(intVal) + } else if field.Type() == typeTimeDuration { + var t time.Duration + if t, err = time.ParseDuration(str); err == nil { + field.SetInt(int64(t)) + } + } + return err +} + +func setUintField(field reflect.Value, str string) error { + if str == "" { + str = "0" + } + uintVal, err := strconv.ParseUint(str, 10, bitSizes[int(field.Kind())]) + if err == nil { + field.SetUint(uintVal) + } + return err +} + +func setBoolField(field reflect.Value, str string) error { + if str == "" { + field.SetBool(true) + return nil + } + boolVal, err := strconv.ParseBool(str) + if err == nil { + field.SetBool(boolVal) + } + return err +} + +func setComplexField(field reflect.Value, str string) error { + str = strings.TrimSuffix(strings.TrimSuffix(strings.TrimPrefix(str, "("), "i"), ")") + pos := strings.Index(str, "+") + if pos == -1 { + pos = len(str) + str += "+0" + } + + read, err := strconv.ParseFloat(str[:pos], bitSizes[int(field.Kind())]) + if err != nil { + return err + } + image, err := strconv.ParseFloat(str[pos+1:], bitSizes[int(field.Kind())]) + if err != nil { + return err + } + + field.SetComplex(complex(read, image)) + return nil +} + +func setFloatField(field reflect.Value, str string) error { + if str == "" { + str = "0.0" + } + floatVal, err := strconv.ParseFloat(str, bitSizes[int(field.Kind())]) + if err == nil { + field.SetFloat(floatVal) + } + return err +} + +// TimeParse 方法通过解析内置支持的时间格式。 +func setTimeField(field reflect.Value, str string) (err error) { + var t time.Time + for i, f := range DefaultValueParseTimeFormats { + if DefaultValueParseTimeFixed[i] && len(str) != len(f) { + continue + } + t, err = time.Parse(f, str) + if err == nil { + if field.Type() != typeTimeTime { + field.Set(reflect.ValueOf(t).Convert(field.Type())) + } else { + field.Set(reflect.ValueOf(t)) + } + return + } + } + return +}