diff --git a/http.go b/http.go index 6197da9..5fb1b5b 100644 --- a/http.go +++ b/http.go @@ -37,8 +37,8 @@ func (p *Proxy) AddHTTPHostRoute(ipPort, httpHost string, dest Target) { // for any additional routes on ipPort. // // The ipPort is any valid net.Listen TCP address. -func (p *Proxy) AddHTTPHostMatchRoute(ipPort string, match Matcher, dest Target) { - p.addRoute(ipPort, httpHostMatch{match, dest}) +func (p *Proxy) AddHTTPHostMatchRoute(ipPort string, match Matcher, dest Target) (routeID int) { + return p.addRoute(ipPort, httpHostMatch{match, dest}) } type httpHostMatch struct { diff --git a/sni.go b/sni.go index 44f5796..edfaf8f 100644 --- a/sni.go +++ b/sni.go @@ -34,8 +34,8 @@ import ( // with AddStopACMESearch. // // The ipPort is any valid net.Listen TCP address. -func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) { - p.AddSNIMatchRoute(ipPort, equals(sni), dest) +func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) (routeID int) { + return p.AddSNIMatchRoute(ipPort, equals(sni), dest) } // AddSNIMatchRoute appends a route to the ipPort listener that routes @@ -48,7 +48,7 @@ func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) { // with AddStopACMESearch. // // The ipPort is any valid net.Listen TCP address. -func (p *Proxy) AddSNIMatchRoute(ipPort string, matcher Matcher, dest Target) { +func (p *Proxy) AddSNIMatchRoute(ipPort string, matcher Matcher, dest Target) (routeID int) { cfg := p.configFor(ipPort) if !cfg.stopACME { if len(cfg.acmeTargets) == 0 { @@ -56,8 +56,7 @@ func (p *Proxy) AddSNIMatchRoute(ipPort string, matcher Matcher, dest Target) { } cfg.acmeTargets = append(cfg.acmeTargets, dest) } - - p.addRoute(ipPort, sniMatch{matcher, dest}) + return p.addRoute(ipPort, sniMatch{matcher, dest}) } // AddStopACMESearch prevents ACME probing of subsequent SNI routes. @@ -74,7 +73,7 @@ type sniMatch struct { } func (m sniMatch) match(br *bufio.Reader) Target { - if m.matcher(context.TODO(), clientHelloServerName(br)) { + if m.matcher(context.TODO(), ClientHelloServerName(br)) { return m.target } return nil @@ -88,7 +87,7 @@ type acmeMatch struct { } func (m *acmeMatch) match(br *bufio.Reader) Target { - sni := clientHelloServerName(br) + sni := ClientHelloServerName(br) if !strings.HasSuffix(sni, ".acme.invalid") { return nil } @@ -153,10 +152,10 @@ func tryACME(ctx context.Context, ch chan<- Target, dest Target, sni string) { ret = dest } -// clientHelloServerName returns the SNI server name inside the TLS ClientHello, +// ClientHelloServerName returns the SNI server name inside the TLS ClientHello, // without consuming any bytes from br. // On any error, the empty string is returned. -func clientHelloServerName(br *bufio.Reader) (sni string) { +func ClientHelloServerName(br *bufio.Reader) (sni string) { const recordHeaderLen = 5 hdr, err := br.Peek(recordHeaderLen) if err != nil { diff --git a/tcpproxy.go b/tcpproxy.go index 8c33604..16eeed4 100644 --- a/tcpproxy.go +++ b/tcpproxy.go @@ -60,6 +60,7 @@ import ( "io" "log" "net" + "sync" "time" ) @@ -79,6 +80,9 @@ type Proxy struct { // function. If nil, net.Dial is used. // The provided net is always "tcp". ListenFunc func(net, laddr string) (net.Listener, error) + + // defaultHandler handles unmatched traffic + defaultHandler Target } // Matcher reports whether hostname matches the Matcher's criteria. @@ -93,7 +97,9 @@ func equals(want string) Matcher { // config contains the proxying state for one listener. type config struct { - routes []route + routes *sync.Map // map[int]route + nextRouteID int + acmeTargets []Target // accumulates targets that should be probed for acme. stopACME bool // if true, AddSNIRoute doesn't add targets to acmeTargets. } @@ -122,25 +128,58 @@ func (p *Proxy) configFor(ipPort string) *config { p.configs = make(map[string]*config) } if p.configs[ipPort] == nil { - p.configs[ipPort] = &config{} + cfg := &config{} + cfg.routes = &sync.Map{} + cfg.nextRouteID = 1 + p.configs[ipPort] = cfg } return p.configs[ipPort] } -func (p *Proxy) addRoute(ipPort string, r route) { - cfg := p.configFor(ipPort) - cfg.routes = append(cfg.routes, r) +func (p *Proxy) addRoute(ipPort string, r route) (routeID int) { + var cfg *config + if p.donec != nil { + // NOTE: Do not create config file if the server is listening. + // This saves the handling of bringing up and tearing down + // listeners when add or remove route. + cfg = p.configs[ipPort] + } else { + cfg = p.configFor(ipPort) + } + if cfg != nil { + routeID = cfg.nextRouteID + cfg.nextRouteID++ + cfg.routes.Store(routeID, r) + } + return +} + +// SetDefaultHandler sets the default handler for proxy. +func (p *Proxy) SetDefaultHandler(t Target) { + p.defaultHandler = t } // AddRoute appends an always-matching route to the ipPort listener, -// directing any connection to dest. +// directing any connection to dest. The added route's id is returned +// for future removal. If routeID is zero, the route is not registered. // // This is generally used as either the only rule (for simple TCP // proxies), or as the final fallback rule for an ipPort. // // The ipPort is any valid net.Listen TCP address. -func (p *Proxy) AddRoute(ipPort string, dest Target) { - p.addRoute(ipPort, fixedTarget{dest}) +func (p *Proxy) AddRoute(ipPort string, dest Target) (routeID int) { + return p.addRoute(ipPort, fixedTarget{dest}) +} + +// RemoveRoute removes an existing route for ipPort. If the route is +// not found, this is an no-op. +// +// Both AddRoute and RemoveRoute is go-routine safe. +func (p *Proxy) RemoveRoute(ipPort string, routeID int) { + cfg := p.configs[ipPort] + if cfg != nil { + cfg.routes.Delete(routeID) + } } type fixedTarget struct { @@ -197,7 +236,7 @@ func (p *Proxy) Start() error { return err } p.lns = append(p.lns, ln) - go p.serveListener(errc, ln, config.routes) + go p.serveListener(errc, ln, config) } go p.awaitFirstError(errc) return nil @@ -208,22 +247,24 @@ func (p *Proxy) awaitFirstError(errc <-chan error) { close(p.donec) } -func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, routes []route) { +func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, cfg *config) { for { c, err := ln.Accept() if err != nil { ret <- err return } - go p.serveConn(c, routes) + go p.serveConn(c, cfg) } } // serveConn runs in its own goroutine and matches c against routes. // It returns whether it matched purely for testing. -func (p *Proxy) serveConn(c net.Conn, routes []route) bool { +func (p *Proxy) serveConn(c net.Conn, cfg *config) bool { br := bufio.NewReader(c) - for _, route := range routes { + var handled bool + cfg.routes.Range(func(k, v interface{}) bool { + route := v.(route) if target := route.match(br); target != nil { if n := br.Buffered(); n > 0 { peeked, _ := br.Peek(br.Buffered()) @@ -233,13 +274,21 @@ func (p *Proxy) serveConn(c net.Conn, routes []route) bool { } } target.HandleConn(c) - return true + handled = true + return false // exit the iteration + } + return true + }) + if !handled { + if p.defaultHandler != nil { + p.defaultHandler.HandleConn(c) + } else { + // TODO: hook for this? + log.Printf("tcpproxy: no routes matched conn %v/%v; closing", c.RemoteAddr().String(), c.LocalAddr().String()) + c.Close() } } - // TODO: hook for this? - log.Printf("tcpproxy: no routes matched conn %v/%v; closing", c.RemoteAddr().String(), c.LocalAddr().String()) - c.Close() - return false + return handled } // Conn is an incoming connection that has had some bytes read from it diff --git a/tcpproxy_test.go b/tcpproxy_test.go index 682214d..dd88253 100644 --- a/tcpproxy_test.go +++ b/tcpproxy_test.go @@ -169,38 +169,90 @@ func testProxy(t *testing.T, front net.Listener) *Proxy { } } -func TestProxyAlwaysMatch(t *testing.T) { - front := newLocalListener(t) - defer front.Close() - back := newLocalListener(t) - defer back.Close() +func testRouteToBackendWithExpected(t *testing.T, toFront net.Conn, back net.Listener, msg string, expected string) { + io.WriteString(toFront, msg) + fromProxy, err := back.Accept() + if err != nil { + t.Fatal(err) + } - p := testProxy(t, front) - p.AddRoute(testFrontAddr, To(back.Addr().String())) - if err := p.Start(); err != nil { + buf := make([]byte, len(expected)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { t.Fatal(err) } + if string(buf) != expected { + t.Fatalf("got %q; want %q", buf, expected) + } +} +func testRouteToBackend(t *testing.T, front net.Listener, back net.Listener, msg string) { toFront, err := net.Dial("tcp", front.Addr().String()) if err != nil { t.Fatal(err) } defer toFront.Close() - fromProxy, err := back.Accept() + testRouteToBackendWithExpected(t, toFront, back, msg, msg) +} + +// test the backend is not receiving traffic +func testNotRouteToBackend(t *testing.T, front net.Listener, back net.Listener, msg string) <-chan bool { + done := make(chan bool) + toFront, err := net.Dial("tcp", front.Addr().String()) if err != nil { t.Fatal(err) } - const msg = "message" - io.WriteString(toFront, msg) + defer toFront.Close() - buf := make([]byte, len(msg)) - if _, err := io.ReadFull(fromProxy, buf); err != nil { + timeC := time.NewTimer(10 * time.Millisecond).C + acceptC := make(chan struct{}) + go func() { + io.WriteString(toFront, msg) + fromProxy, err := back.Accept() + acceptC <- struct{}{} + { + if err == nil { + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { + t.Fatal(err) + } + t.Fatalf("Expect backend to not receive message, but found %s", string(buf)) + } + err, ok := err.(net.Error) + if !ok || !err.Timeout() { + t.Fatalf("Expect backend to timeout, but found err: %v", err) + } + } + }() + go func() { + select { + case <-timeC: + { + done <- true + } + case <-acceptC: + { + t.Fatal("Expect backend to not receive message") + done <- true + } + } + }() + return done +} + +func TestProxyAlwaysMatch(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + back := newLocalListener(t) + defer back.Close() + + p := testProxy(t, front) + p.AddRoute(testFrontAddr, To(back.Addr().String())) + if err := p.Start(); err != nil { t.Fatal(err) } - if string(buf) != msg { - t.Fatalf("got %q; want %q", buf, msg) - } + + testRouteToBackend(t, front, back, "message") } func TestProxyHTTP(t *testing.T) { @@ -219,27 +271,9 @@ func TestProxyHTTP(t *testing.T) { t.Fatal(err) } - toFront, err := net.Dial("tcp", front.Addr().String()) - if err != nil { - t.Fatal(err) - } - defer toFront.Close() - - const msg = "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n" - io.WriteString(toFront, msg) - - fromProxy, err := backBar.Accept() - if err != nil { - t.Fatal(err) - } - - buf := make([]byte, len(msg)) - if _, err := io.ReadFull(fromProxy, buf); err != nil { - t.Fatal(err) - } - if string(buf) != msg { - t.Fatalf("got %q; want %q", buf, msg) - } + testRouteToBackend(t, front, backBar, "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n") + <-testNotRouteToBackend(t, front, backBar, "GET / HTTP/1.1\r\nHost: boo.com\r\n\r\n") + testRouteToBackend(t, front, backFoo, "GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n") } func TestProxySNI(t *testing.T) { @@ -258,27 +292,32 @@ func TestProxySNI(t *testing.T) { t.Fatal(err) } - toFront, err := net.Dial("tcp", front.Addr().String()) - if err != nil { - t.Fatal(err) - } - defer toFront.Close() + testRouteToBackend(t, front, backBar, clientHelloRecord(t, "bar.com")) + <-testNotRouteToBackend(t, front, backBar, clientHelloRecord(t, "foo.com")) + testRouteToBackend(t, front, backFoo, clientHelloRecord(t, "foo.com")) +} - msg := clientHelloRecord(t, "bar.com") - io.WriteString(toFront, msg) +func TestProxyRemoveRoute(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + p := testProxy(t, front) - fromProxy, err := backBar.Accept() - if err != nil { - t.Fatal(err) - } + // NOTE: Needs to register testFrontAddr before server starts + p.AddSNIRoute(testFrontAddr, "unused.com", noopTarget{}) - buf := make([]byte, len(msg)) - if _, err := io.ReadFull(fromProxy, buf); err != nil { + if err := p.Start(); err != nil { t.Fatal(err) } - if string(buf) != msg { - t.Fatalf("got %q; want %q", buf, msg) - } + + backBar := newLocalListener(t) + defer backBar.Close() + routeID := p.AddSNIRoute(testFrontAddr, "bar.com", To(backBar.Addr().String())) + + msg := clientHelloRecord(t, "bar.com") + testRouteToBackend(t, front, backBar, msg) + + p.RemoveRoute(testFrontAddr, routeID) + <-testNotRouteToBackend(t, front, backBar, msg) } func TestProxyPROXYOut(t *testing.T) { @@ -301,23 +340,8 @@ func TestProxyPROXYOut(t *testing.T) { t.Fatal(err) } - io.WriteString(toFront, "foo") - toFront.Close() - - fromProxy, err := back.Accept() - if err != nil { - t.Fatal(err) - } - - bs, err := ioutil.ReadAll(fromProxy) - if err != nil { - t.Fatal(err) - } - want := fmt.Sprintf("PROXY TCP4 %s %d %s %d\r\nfoo", toFront.LocalAddr().(*net.TCPAddr).IP, toFront.LocalAddr().(*net.TCPAddr).Port, toFront.RemoteAddr().(*net.TCPAddr).IP, toFront.RemoteAddr().(*net.TCPAddr).Port) - if string(bs) != want { - t.Fatalf("got %q; want %q", bs, want) - } + testRouteToBackendWithExpected(t, toFront, back, "foo", want) } type tlsServer struct {