From f89ab2b0e0d6f0443494ddf733fbb1f77ae2e898 Mon Sep 17 00:00:00 2001 From: Steven Yang Date: Fri, 6 Apr 2018 22:17:47 -0700 Subject: [PATCH 01/10] Implemented Proxy.RemoveRoute - Use a map instead of slice to hold routes for each configuration - Use a lock to guard read / write to the routes - Issue a serial number for the route added in order to remove them later - Add/RemoveRoute is thread-safe --- tcpproxy.go | 53 +++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 41 insertions(+), 12 deletions(-) diff --git a/tcpproxy.go b/tcpproxy.go index 8c33604..fe7ab01 100644 --- a/tcpproxy.go +++ b/tcpproxy.go @@ -60,6 +60,7 @@ import ( "io" "log" "net" + "sync" "time" ) @@ -93,11 +94,19 @@ func equals(want string) Matcher { // config contains the proxying state for one listener. type config struct { - routes []route + sync.Mutex // protect r/w of routes + nextRouteId int + routes map[int]route acmeTargets []Target // accumulates targets that should be probed for acme. stopACME bool // if true, AddSNIRoute doesn't add targets to acmeTargets. } +func NewConfig() (cfg *config) { + cfg = &config{} + cfg.routes = make(map[int]route) + return +} + // A route matches a connection to a target. type route interface { // match examines the initial bytes of a connection, looking for a @@ -122,25 +131,43 @@ func (p *Proxy) configFor(ipPort string) *config { p.configs = make(map[string]*config) } if p.configs[ipPort] == nil { - p.configs[ipPort] = &config{} + p.configs[ipPort] = NewConfig() } return p.configs[ipPort] } -func (p *Proxy) addRoute(ipPort string, r route) { +func (p *Proxy) addRoute(ipPort string, r route) (routeId int) { cfg := p.configFor(ipPort) - cfg.routes = append(cfg.routes, r) + cfg.Lock() + defer cfg.Unlock() + routeId = cfg.nextRouteId + cfg.nextRouteId++ + cfg.routes[routeId] = r + return } // AddRoute appends an always-matching route to the ipPort listener, -// directing any connection to dest. +// directing any connection to dest. The added route's id is returned +// for future removal. // // This is generally used as either the only rule (for simple TCP // proxies), or as the final fallback rule for an ipPort. // // The ipPort is any valid net.Listen TCP address. -func (p *Proxy) AddRoute(ipPort string, dest Target) { - p.addRoute(ipPort, fixedTarget{dest}) +func (p *Proxy) AddRoute(ipPort string, dest Target) (routeId int) { + return p.addRoute(ipPort, fixedTarget{dest}) +} + +// RemoveRoute removes an existing route for ipPort. If the route is +// not found, this is an no-op. +// +// Both AddRoute and RemoveRoute is go-routine safe. +func (p *Proxy) RemoveRoute(ipPort string, routeId int) (err error) { + cfg := p.configFor(ipPort) + cfg.Lock() + defer cfg.Unlock() + delete(cfg.routes, routeId) + return } type fixedTarget struct { @@ -197,7 +224,7 @@ func (p *Proxy) Start() error { return err } p.lns = append(p.lns, ln) - go p.serveListener(errc, ln, config.routes) + go p.serveListener(errc, ln, config) } go p.awaitFirstError(errc) return nil @@ -208,22 +235,24 @@ func (p *Proxy) awaitFirstError(errc <-chan error) { close(p.donec) } -func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, routes []route) { +func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, cfg *config) { for { c, err := ln.Accept() if err != nil { ret <- err return } - go p.serveConn(c, routes) + go p.serveConn(c, cfg) } } // serveConn runs in its own goroutine and matches c against routes. // It returns whether it matched purely for testing. -func (p *Proxy) serveConn(c net.Conn, routes []route) bool { +func (p *Proxy) serveConn(c net.Conn, cfg *config) bool { br := bufio.NewReader(c) - for _, route := range routes { + cfg.Lock() + defer cfg.Unlock() + for _, route := range cfg.routes { if target := route.match(br); target != nil { if n := br.Buffered(); n > 0 { peeked, _ := br.Peek(br.Buffered()) From 7c3486f1a006517010a72ba2c41ab1ca5504195f Mon Sep 17 00:00:00 2001 From: Steven Yang Date: Fri, 6 Apr 2018 22:26:52 -0700 Subject: [PATCH 02/10] Propagate the new addRoute API to SNI and HTTP host --- http.go | 4 ++-- sni.go | 9 ++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/http.go b/http.go index 6197da9..1d3e422 100644 --- a/http.go +++ b/http.go @@ -37,8 +37,8 @@ func (p *Proxy) AddHTTPHostRoute(ipPort, httpHost string, dest Target) { // for any additional routes on ipPort. // // The ipPort is any valid net.Listen TCP address. -func (p *Proxy) AddHTTPHostMatchRoute(ipPort string, match Matcher, dest Target) { - p.addRoute(ipPort, httpHostMatch{match, dest}) +func (p *Proxy) AddHTTPHostMatchRoute(ipPort string, match Matcher, dest Target) (routeId int) { + return p.addRoute(ipPort, httpHostMatch{match, dest}) } type httpHostMatch struct { diff --git a/sni.go b/sni.go index 44f5796..bb0a446 100644 --- a/sni.go +++ b/sni.go @@ -34,8 +34,8 @@ import ( // with AddStopACMESearch. // // The ipPort is any valid net.Listen TCP address. -func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) { - p.AddSNIMatchRoute(ipPort, equals(sni), dest) +func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) (routeId int) { + return p.AddSNIMatchRoute(ipPort, equals(sni), dest) } // AddSNIMatchRoute appends a route to the ipPort listener that routes @@ -48,7 +48,7 @@ func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) { // with AddStopACMESearch. // // The ipPort is any valid net.Listen TCP address. -func (p *Proxy) AddSNIMatchRoute(ipPort string, matcher Matcher, dest Target) { +func (p *Proxy) AddSNIMatchRoute(ipPort string, matcher Matcher, dest Target) (routeId int) { cfg := p.configFor(ipPort) if !cfg.stopACME { if len(cfg.acmeTargets) == 0 { @@ -56,8 +56,7 @@ func (p *Proxy) AddSNIMatchRoute(ipPort string, matcher Matcher, dest Target) { } cfg.acmeTargets = append(cfg.acmeTargets, dest) } - - p.addRoute(ipPort, sniMatch{matcher, dest}) + return p.addRoute(ipPort, sniMatch{matcher, dest}) } // AddStopACMESearch prevents ACME probing of subsequent SNI routes. From 26f89279c8b5ada47124e75f6176b82a56fbab00 Mon Sep 17 00:00:00 2001 From: Steven Yang Date: Fri, 6 Apr 2018 22:44:08 -0700 Subject: [PATCH 03/10] Extract a test helper to correct front -> back msg passing --- tcpproxy_test.go | 70 ++++++++++++++++++------------------------------ 1 file changed, 26 insertions(+), 44 deletions(-) diff --git a/tcpproxy_test.go b/tcpproxy_test.go index 682214d..7a62185 100644 --- a/tcpproxy_test.go +++ b/tcpproxy_test.go @@ -169,6 +169,26 @@ func testProxy(t *testing.T, front net.Listener) *Proxy { } } +func testRouteToBackendWithExpected(t *testing.T, front net.Conn, back net.Listener, msg string, expected string) { + io.WriteString(front, msg) + fromProxy, err := back.Accept() + if err != nil { + t.Fatal(err) + } + + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { + t.Fatal(err) + } + if string(buf) != expected { + t.Fatalf("got %q; want %q", buf, expected) + } +} + +func testRouteToBackend(t *testing.T, front net.Conn, back net.Listener, msg string) { + testRouteToBackendWithExpected(t, front, back, msg, msg) +} + func TestProxyAlwaysMatch(t *testing.T) { front := newLocalListener(t) defer front.Close() @@ -187,20 +207,7 @@ func TestProxyAlwaysMatch(t *testing.T) { } defer toFront.Close() - fromProxy, err := back.Accept() - if err != nil { - t.Fatal(err) - } - const msg = "message" - io.WriteString(toFront, msg) - - buf := make([]byte, len(msg)) - if _, err := io.ReadFull(fromProxy, buf); err != nil { - t.Fatal(err) - } - if string(buf) != msg { - t.Fatalf("got %q; want %q", buf, msg) - } + testRouteToBackend(t, toFront, back, "message") } func TestProxyHTTP(t *testing.T) { @@ -226,20 +233,7 @@ func TestProxyHTTP(t *testing.T) { defer toFront.Close() const msg = "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n" - io.WriteString(toFront, msg) - - fromProxy, err := backBar.Accept() - if err != nil { - t.Fatal(err) - } - - buf := make([]byte, len(msg)) - if _, err := io.ReadFull(fromProxy, buf); err != nil { - t.Fatal(err) - } - if string(buf) != msg { - t.Fatalf("got %q; want %q", buf, msg) - } + testRouteToBackend(t, toFront, backBar, msg) } func TestProxySNI(t *testing.T) { @@ -264,6 +258,9 @@ func TestProxySNI(t *testing.T) { } defer toFront.Close() + msg := clientHelloRecord(t, "bar.com") + testRouteToBackend(t, toFront, backBar, msg) +} msg := clientHelloRecord(t, "bar.com") io.WriteString(toFront, msg) @@ -301,23 +298,8 @@ func TestProxyPROXYOut(t *testing.T) { t.Fatal(err) } - io.WriteString(toFront, "foo") - toFront.Close() - - fromProxy, err := back.Accept() - if err != nil { - t.Fatal(err) - } - - bs, err := ioutil.ReadAll(fromProxy) - if err != nil { - t.Fatal(err) - } - want := fmt.Sprintf("PROXY TCP4 %s %d %s %d\r\nfoo", toFront.LocalAddr().(*net.TCPAddr).IP, toFront.LocalAddr().(*net.TCPAddr).Port, toFront.RemoteAddr().(*net.TCPAddr).IP, toFront.RemoteAddr().(*net.TCPAddr).Port) - if string(bs) != want { - t.Fatalf("got %q; want %q", bs, want) - } + testRouteToBackendWithExpected(t, toFront, back, "foo", want) } type tlsServer struct { From cfba6a1d5614adfeb0da9e1de43f5c457b223f8d Mon Sep 17 00:00:00 2001 From: Steven Yang Date: Fri, 6 Apr 2018 23:14:02 -0700 Subject: [PATCH 04/10] Handle the case when adding route after server starts - returns null routeId to indicate a failed registration --- tcpproxy.go | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/tcpproxy.go b/tcpproxy.go index fe7ab01..fcc4b5d 100644 --- a/tcpproxy.go +++ b/tcpproxy.go @@ -94,7 +94,7 @@ func equals(want string) Matcher { // config contains the proxying state for one listener. type config struct { - sync.Mutex // protect r/w of routes + sync.Mutex // protect w of routes nextRouteId int routes map[int]route acmeTargets []Target // accumulates targets that should be probed for acme. @@ -104,6 +104,7 @@ type config struct { func NewConfig() (cfg *config) { cfg = &config{} cfg.routes = make(map[int]route) + cfg.nextRouteId = 1 return } @@ -137,18 +138,28 @@ func (p *Proxy) configFor(ipPort string) *config { } func (p *Proxy) addRoute(ipPort string, r route) (routeId int) { - cfg := p.configFor(ipPort) - cfg.Lock() - defer cfg.Unlock() - routeId = cfg.nextRouteId - cfg.nextRouteId++ - cfg.routes[routeId] = r + var cfg *config + if p.donec != nil { + // NOTE: Do not create config file if the server is listening. + // This saves the handling of bringing up and tearing down + // listeners when add or remove route. + cfg = p.configs[ipPort] + } else { + cfg = p.configFor(ipPort) + } + if cfg != nil { + cfg.Lock() + routeId = cfg.nextRouteId + cfg.nextRouteId++ + cfg.routes[routeId] = r + cfg.Unlock() + } return } // AddRoute appends an always-matching route to the ipPort listener, // directing any connection to dest. The added route's id is returned -// for future removal. +// for future removal. If routeId is zero, the route is not registered. // // This is generally used as either the only rule (for simple TCP // proxies), or as the final fallback rule for an ipPort. @@ -164,9 +175,7 @@ func (p *Proxy) AddRoute(ipPort string, dest Target) (routeId int) { // Both AddRoute and RemoveRoute is go-routine safe. func (p *Proxy) RemoveRoute(ipPort string, routeId int) (err error) { cfg := p.configFor(ipPort) - cfg.Lock() - defer cfg.Unlock() - delete(cfg.routes, routeId) + cfg.routes[routeId] = nil return } @@ -250,9 +259,10 @@ func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, cfg *config) { // It returns whether it matched purely for testing. func (p *Proxy) serveConn(c net.Conn, cfg *config) bool { br := bufio.NewReader(c) - cfg.Lock() - defer cfg.Unlock() for _, route := range cfg.routes { + if route == nil { + continue + } if target := route.match(br); target != nil { if n := br.Buffered(); n > 0 { peeked, _ := br.Peek(br.Buffered()) From cfffc790a765a28f643bedb2d53cb87803fb1feb Mon Sep 17 00:00:00 2001 From: Steven Yang Date: Sat, 7 Apr 2018 00:27:13 -0700 Subject: [PATCH 05/10] Added the test helper to assert no message reached backend - Add the test case for existing backends - Add the test case for add/remove route after server starts --- tcpproxy_test.go | 144 ++++++++++++++++++++++++++++++----------------- 1 file changed, 93 insertions(+), 51 deletions(-) diff --git a/tcpproxy_test.go b/tcpproxy_test.go index 7a62185..e98a241 100644 --- a/tcpproxy_test.go +++ b/tcpproxy_test.go @@ -169,24 +169,75 @@ func testProxy(t *testing.T, front net.Listener) *Proxy { } } -func testRouteToBackendWithExpected(t *testing.T, front net.Conn, back net.Listener, msg string, expected string) { - io.WriteString(front, msg) - fromProxy, err := back.Accept() - if err != nil { - t.Fatal(err) - } - - buf := make([]byte, len(msg)) - if _, err := io.ReadFull(fromProxy, buf); err != nil { - t.Fatal(err) - } - if string(buf) != expected { - t.Fatalf("got %q; want %q", buf, expected) - } +func testRouteToBackendWithExpected(t *testing.T, toFront net.Conn, back net.Listener, msg string, expected string) { + io.WriteString(toFront, msg) + fromProxy, err := back.Accept() + if err != nil { + t.Fatal(err) + } + + buf := make([]byte, len(expected)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { + t.Fatal(err) + } + if string(buf) != expected { + t.Fatalf("got %q; want %q", buf, expected) + } } -func testRouteToBackend(t *testing.T, front net.Conn, back net.Listener, msg string) { - testRouteToBackendWithExpected(t, front, back, msg, msg) +func testRouteToBackend(t *testing.T, front net.Listener, back net.Listener, msg string) { + toFront, err := net.Dial("tcp", front.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer toFront.Close() + + testRouteToBackendWithExpected(t, toFront, back, msg, msg) +} + +// test the backend is not receiving traffic +func testNotRouteToBackend(t *testing.T, front net.Listener, back net.Listener, msg string) <-chan bool { + done := make(chan bool) + toFront, err := net.Dial("tcp", front.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer toFront.Close() + + timeC := time.NewTimer(10 * time.Millisecond).C + acceptC := make(chan struct{}) + go func() { + io.WriteString(toFront, msg) + fromProxy, err := back.Accept() + acceptC <- struct{}{} + { + if err == nil { + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { + t.Fatal(err) + } + t.Fatalf("Expect backend to not receive message, but found %s", string(buf)) + } + err, ok := err.(net.Error) + if !ok || !err.Timeout() { + t.Fatalf("Expect backend to timeout, but found err: %v", err) + } + } + }() + go func() { + select { + case <-timeC: + { + done <- true + } + case <-acceptC: + { + t.Fatal("Expect backend to not receive message") + done <- true + } + } + }() + return done } func TestProxyAlwaysMatch(t *testing.T) { @@ -201,13 +252,7 @@ func TestProxyAlwaysMatch(t *testing.T) { t.Fatal(err) } - toFront, err := net.Dial("tcp", front.Addr().String()) - if err != nil { - t.Fatal(err) - } - defer toFront.Close() - - testRouteToBackend(t, toFront, back, "message") + testRouteToBackend(t, front, back, "message") } func TestProxyHTTP(t *testing.T) { @@ -226,14 +271,9 @@ func TestProxyHTTP(t *testing.T) { t.Fatal(err) } - toFront, err := net.Dial("tcp", front.Addr().String()) - if err != nil { - t.Fatal(err) - } - defer toFront.Close() - - const msg = "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n" - testRouteToBackend(t, toFront, backBar, msg) + testRouteToBackend(t, front, backBar, "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n") + <-testNotRouteToBackend(t, front, backBar, "GET / HTTP/1.1\r\nHost: boo.com\r\n\r\n") + testRouteToBackend(t, front, backFoo, "GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n") } func TestProxySNI(t *testing.T) { @@ -252,30 +292,32 @@ func TestProxySNI(t *testing.T) { t.Fatal(err) } - toFront, err := net.Dial("tcp", front.Addr().String()) - if err != nil { - t.Fatal(err) - } - defer toFront.Close() - - msg := clientHelloRecord(t, "bar.com") - testRouteToBackend(t, toFront, backBar, msg) + testRouteToBackend(t, front, backBar, clientHelloRecord(t, "bar.com")) + <-testNotRouteToBackend(t, front, backBar, clientHelloRecord(t, "foo.com")) + testRouteToBackend(t, front, backFoo, clientHelloRecord(t, "foo.com")) } - msg := clientHelloRecord(t, "bar.com") - io.WriteString(toFront, msg) - fromProxy, err := backBar.Accept() - if err != nil { - t.Fatal(err) - } +func TestProxyRemoveRoute(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + p := testProxy(t, front) - buf := make([]byte, len(msg)) - if _, err := io.ReadFull(fromProxy, buf); err != nil { + // NOTE: Needs to register testFrontAddr before server starts + p.AddSNIRoute(testFrontAddr, "unused.com", noopTarget{}) + + if err := p.Start(); err != nil { t.Fatal(err) } - if string(buf) != msg { - t.Fatalf("got %q; want %q", buf, msg) - } + + backBar := newLocalListener(t) + defer backBar.Close() + routeId := p.AddSNIRoute(testFrontAddr, "bar.com", To(backBar.Addr().String())) + + msg := clientHelloRecord(t, "bar.com") + testRouteToBackend(t, front, backBar, msg) + + p.RemoveRoute(testFrontAddr, routeId) + <-testNotRouteToBackend(t, front, backBar, msg) } func TestProxyPROXYOut(t *testing.T) { @@ -299,7 +341,7 @@ func TestProxyPROXYOut(t *testing.T) { } want := fmt.Sprintf("PROXY TCP4 %s %d %s %d\r\nfoo", toFront.LocalAddr().(*net.TCPAddr).IP, toFront.LocalAddr().(*net.TCPAddr).Port, toFront.RemoteAddr().(*net.TCPAddr).IP, toFront.RemoteAddr().(*net.TCPAddr).Port) - testRouteToBackendWithExpected(t, toFront, back, "foo", want) + testRouteToBackendWithExpected(t, toFront, back, "foo", want) } type tlsServer struct { From a64c39caa18719b71fb0acff3920f9b8350ea9cc Mon Sep 17 00:00:00 2001 From: Steven Yang Date: Sat, 7 Apr 2018 18:52:45 -0700 Subject: [PATCH 06/10] Correct linter warnings --- http.go | 2 +- sni.go | 4 ++-- tcpproxy.go | 24 ++++++++++++------------ tcpproxy_test.go | 4 ++-- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/http.go b/http.go index 1d3e422..5fb1b5b 100644 --- a/http.go +++ b/http.go @@ -37,7 +37,7 @@ func (p *Proxy) AddHTTPHostRoute(ipPort, httpHost string, dest Target) { // for any additional routes on ipPort. // // The ipPort is any valid net.Listen TCP address. -func (p *Proxy) AddHTTPHostMatchRoute(ipPort string, match Matcher, dest Target) (routeId int) { +func (p *Proxy) AddHTTPHostMatchRoute(ipPort string, match Matcher, dest Target) (routeID int) { return p.addRoute(ipPort, httpHostMatch{match, dest}) } diff --git a/sni.go b/sni.go index bb0a446..b49865c 100644 --- a/sni.go +++ b/sni.go @@ -34,7 +34,7 @@ import ( // with AddStopACMESearch. // // The ipPort is any valid net.Listen TCP address. -func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) (routeId int) { +func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) (routeID int) { return p.AddSNIMatchRoute(ipPort, equals(sni), dest) } @@ -48,7 +48,7 @@ func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) (routeId int) { // with AddStopACMESearch. // // The ipPort is any valid net.Listen TCP address. -func (p *Proxy) AddSNIMatchRoute(ipPort string, matcher Matcher, dest Target) (routeId int) { +func (p *Proxy) AddSNIMatchRoute(ipPort string, matcher Matcher, dest Target) (routeID int) { cfg := p.configFor(ipPort) if !cfg.stopACME { if len(cfg.acmeTargets) == 0 { diff --git a/tcpproxy.go b/tcpproxy.go index fcc4b5d..b4aa767 100644 --- a/tcpproxy.go +++ b/tcpproxy.go @@ -95,16 +95,16 @@ func equals(want string) Matcher { // config contains the proxying state for one listener. type config struct { sync.Mutex // protect w of routes - nextRouteId int + nextRouteID int routes map[int]route acmeTargets []Target // accumulates targets that should be probed for acme. stopACME bool // if true, AddSNIRoute doesn't add targets to acmeTargets. } -func NewConfig() (cfg *config) { +func newConfig() (cfg *config) { cfg = &config{} cfg.routes = make(map[int]route) - cfg.nextRouteId = 1 + cfg.nextRouteID = 1 return } @@ -132,12 +132,12 @@ func (p *Proxy) configFor(ipPort string) *config { p.configs = make(map[string]*config) } if p.configs[ipPort] == nil { - p.configs[ipPort] = NewConfig() + p.configs[ipPort] = newConfig() } return p.configs[ipPort] } -func (p *Proxy) addRoute(ipPort string, r route) (routeId int) { +func (p *Proxy) addRoute(ipPort string, r route) (routeID int) { var cfg *config if p.donec != nil { // NOTE: Do not create config file if the server is listening. @@ -149,9 +149,9 @@ func (p *Proxy) addRoute(ipPort string, r route) (routeId int) { } if cfg != nil { cfg.Lock() - routeId = cfg.nextRouteId - cfg.nextRouteId++ - cfg.routes[routeId] = r + routeID = cfg.nextRouteID + cfg.nextRouteID++ + cfg.routes[routeID] = r cfg.Unlock() } return @@ -159,13 +159,13 @@ func (p *Proxy) addRoute(ipPort string, r route) (routeId int) { // AddRoute appends an always-matching route to the ipPort listener, // directing any connection to dest. The added route's id is returned -// for future removal. If routeId is zero, the route is not registered. +// for future removal. If routeID is zero, the route is not registered. // // This is generally used as either the only rule (for simple TCP // proxies), or as the final fallback rule for an ipPort. // // The ipPort is any valid net.Listen TCP address. -func (p *Proxy) AddRoute(ipPort string, dest Target) (routeId int) { +func (p *Proxy) AddRoute(ipPort string, dest Target) (routeID int) { return p.addRoute(ipPort, fixedTarget{dest}) } @@ -173,9 +173,9 @@ func (p *Proxy) AddRoute(ipPort string, dest Target) (routeId int) { // not found, this is an no-op. // // Both AddRoute and RemoveRoute is go-routine safe. -func (p *Proxy) RemoveRoute(ipPort string, routeId int) (err error) { +func (p *Proxy) RemoveRoute(ipPort string, routeID int) (err error) { cfg := p.configFor(ipPort) - cfg.routes[routeId] = nil + cfg.routes[routeID] = nil return } diff --git a/tcpproxy_test.go b/tcpproxy_test.go index e98a241..dd88253 100644 --- a/tcpproxy_test.go +++ b/tcpproxy_test.go @@ -311,12 +311,12 @@ func TestProxyRemoveRoute(t *testing.T) { backBar := newLocalListener(t) defer backBar.Close() - routeId := p.AddSNIRoute(testFrontAddr, "bar.com", To(backBar.Addr().String())) + routeID := p.AddSNIRoute(testFrontAddr, "bar.com", To(backBar.Addr().String())) msg := clientHelloRecord(t, "bar.com") testRouteToBackend(t, front, backBar, msg) - p.RemoveRoute(testFrontAddr, routeId) + p.RemoveRoute(testFrontAddr, routeID) <-testNotRouteToBackend(t, front, backBar, msg) } From 08e1ca7db3a5c63dd6a3000bc5460a55192c38b4 Mon Sep 17 00:00:00 2001 From: Steven Yang Date: Tue, 10 Apr 2018 18:33:13 -0700 Subject: [PATCH 07/10] Review fixes --- tcpproxy.go | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/tcpproxy.go b/tcpproxy.go index b4aa767..9bbc1a4 100644 --- a/tcpproxy.go +++ b/tcpproxy.go @@ -95,19 +95,13 @@ func equals(want string) Matcher { // config contains the proxying state for one listener. type config struct { sync.Mutex // protect w of routes - nextRouteID int routes map[int]route + nextRouteID int + acmeTargets []Target // accumulates targets that should be probed for acme. stopACME bool // if true, AddSNIRoute doesn't add targets to acmeTargets. } -func newConfig() (cfg *config) { - cfg = &config{} - cfg.routes = make(map[int]route) - cfg.nextRouteID = 1 - return -} - // A route matches a connection to a target. type route interface { // match examines the initial bytes of a connection, looking for a @@ -132,7 +126,10 @@ func (p *Proxy) configFor(ipPort string) *config { p.configs = make(map[string]*config) } if p.configs[ipPort] == nil { - p.configs[ipPort] = newConfig() + cfg := &config{} + cfg.routes = make(map[int]route) + cfg.nextRouteID = 1 + p.configs[ipPort] = cfg } return p.configs[ipPort] } @@ -173,10 +170,9 @@ func (p *Proxy) AddRoute(ipPort string, dest Target) (routeID int) { // not found, this is an no-op. // // Both AddRoute and RemoveRoute is go-routine safe. -func (p *Proxy) RemoveRoute(ipPort string, routeID int) (err error) { +func (p *Proxy) RemoveRoute(ipPort string, routeID int) { cfg := p.configFor(ipPort) cfg.routes[routeID] = nil - return } type fixedTarget struct { From d41a00555889cf1b88a4ffd1da26f2668dade088 Mon Sep 17 00:00:00 2001 From: Steven Yang Date: Thu, 24 May 2018 20:21:34 -0700 Subject: [PATCH 08/10] use sync.Map to save routes --- tcpproxy.go | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/tcpproxy.go b/tcpproxy.go index 9bbc1a4..ee7053d 100644 --- a/tcpproxy.go +++ b/tcpproxy.go @@ -94,8 +94,7 @@ func equals(want string) Matcher { // config contains the proxying state for one listener. type config struct { - sync.Mutex // protect w of routes - routes map[int]route + routes *sync.Map // map[int]route nextRouteID int acmeTargets []Target // accumulates targets that should be probed for acme. @@ -127,7 +126,7 @@ func (p *Proxy) configFor(ipPort string) *config { } if p.configs[ipPort] == nil { cfg := &config{} - cfg.routes = make(map[int]route) + cfg.routes = &sync.Map{} cfg.nextRouteID = 1 p.configs[ipPort] = cfg } @@ -145,11 +144,9 @@ func (p *Proxy) addRoute(ipPort string, r route) (routeID int) { cfg = p.configFor(ipPort) } if cfg != nil { - cfg.Lock() routeID = cfg.nextRouteID cfg.nextRouteID++ - cfg.routes[routeID] = r - cfg.Unlock() + cfg.routes.Store(routeID, r) } return } @@ -172,7 +169,7 @@ func (p *Proxy) AddRoute(ipPort string, dest Target) (routeID int) { // Both AddRoute and RemoveRoute is go-routine safe. func (p *Proxy) RemoveRoute(ipPort string, routeID int) { cfg := p.configFor(ipPort) - cfg.routes[routeID] = nil + cfg.routes.Delete(routeID) } type fixedTarget struct { @@ -255,10 +252,9 @@ func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, cfg *config) { // It returns whether it matched purely for testing. func (p *Proxy) serveConn(c net.Conn, cfg *config) bool { br := bufio.NewReader(c) - for _, route := range cfg.routes { - if route == nil { - continue - } + var handled bool + cfg.routes.Range(func(k, v interface{}) bool { + route := v.(route) if target := route.match(br); target != nil { if n := br.Buffered(); n > 0 { peeked, _ := br.Peek(br.Buffered()) @@ -268,13 +264,17 @@ func (p *Proxy) serveConn(c net.Conn, cfg *config) bool { } } target.HandleConn(c) - return true + handled = true + return false // exit the iteration } + return true + }) + if !handled { + // TODO: hook for this? + log.Printf("tcpproxy: no routes matched conn %v/%v; closing", c.RemoteAddr().String(), c.LocalAddr().String()) + c.Close() } - // TODO: hook for this? - log.Printf("tcpproxy: no routes matched conn %v/%v; closing", c.RemoteAddr().String(), c.LocalAddr().String()) - c.Close() - return false + return handled } // Conn is an incoming connection that has had some bytes read from it From 9a40e9da514883d00dfa63fc9bb52bff45d46b6d Mon Sep 17 00:00:00 2001 From: Steven Yang Date: Wed, 6 Jun 2018 15:41:46 -0700 Subject: [PATCH 09/10] do not create the config during route deletion --- tcpproxy.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tcpproxy.go b/tcpproxy.go index ee7053d..6c357af 100644 --- a/tcpproxy.go +++ b/tcpproxy.go @@ -168,8 +168,10 @@ func (p *Proxy) AddRoute(ipPort string, dest Target) (routeID int) { // // Both AddRoute and RemoveRoute is go-routine safe. func (p *Proxy) RemoveRoute(ipPort string, routeID int) { - cfg := p.configFor(ipPort) - cfg.routes.Delete(routeID) + cfg := p.configs[ipPort] + if cfg != nil { + cfg.routes.Delete(routeID) + } } type fixedTarget struct { From 2041ee5cacf948bc4018419beed8d3ba750100b0 Mon Sep 17 00:00:00 2001 From: Steven Yang Date: Sun, 10 Jun 2018 20:06:43 -0700 Subject: [PATCH 10/10] added support for default handler --- sni.go | 8 ++++---- tcpproxy.go | 18 +++++++++++++++--- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/sni.go b/sni.go index b49865c..edfaf8f 100644 --- a/sni.go +++ b/sni.go @@ -73,7 +73,7 @@ type sniMatch struct { } func (m sniMatch) match(br *bufio.Reader) Target { - if m.matcher(context.TODO(), clientHelloServerName(br)) { + if m.matcher(context.TODO(), ClientHelloServerName(br)) { return m.target } return nil @@ -87,7 +87,7 @@ type acmeMatch struct { } func (m *acmeMatch) match(br *bufio.Reader) Target { - sni := clientHelloServerName(br) + sni := ClientHelloServerName(br) if !strings.HasSuffix(sni, ".acme.invalid") { return nil } @@ -152,10 +152,10 @@ func tryACME(ctx context.Context, ch chan<- Target, dest Target, sni string) { ret = dest } -// clientHelloServerName returns the SNI server name inside the TLS ClientHello, +// ClientHelloServerName returns the SNI server name inside the TLS ClientHello, // without consuming any bytes from br. // On any error, the empty string is returned. -func clientHelloServerName(br *bufio.Reader) (sni string) { +func ClientHelloServerName(br *bufio.Reader) (sni string) { const recordHeaderLen = 5 hdr, err := br.Peek(recordHeaderLen) if err != nil { diff --git a/tcpproxy.go b/tcpproxy.go index 6c357af..16eeed4 100644 --- a/tcpproxy.go +++ b/tcpproxy.go @@ -80,6 +80,9 @@ type Proxy struct { // function. If nil, net.Dial is used. // The provided net is always "tcp". ListenFunc func(net, laddr string) (net.Listener, error) + + // defaultHandler handles unmatched traffic + defaultHandler Target } // Matcher reports whether hostname matches the Matcher's criteria. @@ -151,6 +154,11 @@ func (p *Proxy) addRoute(ipPort string, r route) (routeID int) { return } +// SetDefaultHandler sets the default handler for proxy. +func (p *Proxy) SetDefaultHandler(t Target) { + p.defaultHandler = t +} + // AddRoute appends an always-matching route to the ipPort listener, // directing any connection to dest. The added route's id is returned // for future removal. If routeID is zero, the route is not registered. @@ -272,9 +280,13 @@ func (p *Proxy) serveConn(c net.Conn, cfg *config) bool { return true }) if !handled { - // TODO: hook for this? - log.Printf("tcpproxy: no routes matched conn %v/%v; closing", c.RemoteAddr().String(), c.LocalAddr().String()) - c.Close() + if p.defaultHandler != nil { + p.defaultHandler.HandleConn(c) + } else { + // TODO: hook for this? + log.Printf("tcpproxy: no routes matched conn %v/%v; closing", c.RemoteAddr().String(), c.LocalAddr().String()) + c.Close() + } } return handled }