Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable add/remove routes after server starts #15

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
4 changes: 2 additions & 2 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ func (p *Proxy) AddHTTPHostRoute(ipPort, httpHost string, dest Target) {
// for any additional routes on ipPort.
//
// The ipPort is any valid net.Listen TCP address.
func (p *Proxy) AddHTTPHostMatchRoute(ipPort string, match Matcher, dest Target) {
p.addRoute(ipPort, httpHostMatch{match, dest})
func (p *Proxy) AddHTTPHostMatchRoute(ipPort string, match Matcher, dest Target) (routeID int) {
return p.addRoute(ipPort, httpHostMatch{match, dest})
}

type httpHostMatch struct {
Expand Down
17 changes: 8 additions & 9 deletions sni.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ import (
// with AddStopACMESearch.
//
// The ipPort is any valid net.Listen TCP address.
func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) {
p.AddSNIMatchRoute(ipPort, equals(sni), dest)
func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) (routeID int) {
return p.AddSNIMatchRoute(ipPort, equals(sni), dest)
}

// AddSNIMatchRoute appends a route to the ipPort listener that routes
Expand All @@ -48,16 +48,15 @@ func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) {
// with AddStopACMESearch.
//
// The ipPort is any valid net.Listen TCP address.
func (p *Proxy) AddSNIMatchRoute(ipPort string, matcher Matcher, dest Target) {
func (p *Proxy) AddSNIMatchRoute(ipPort string, matcher Matcher, dest Target) (routeID int) {
cfg := p.configFor(ipPort)
if !cfg.stopACME {
if len(cfg.acmeTargets) == 0 {
p.addRoute(ipPort, &acmeMatch{cfg})
}
cfg.acmeTargets = append(cfg.acmeTargets, dest)
}

p.addRoute(ipPort, sniMatch{matcher, dest})
return p.addRoute(ipPort, sniMatch{matcher, dest})
}

// AddStopACMESearch prevents ACME probing of subsequent SNI routes.
Expand All @@ -74,7 +73,7 @@ type sniMatch struct {
}

func (m sniMatch) match(br *bufio.Reader) Target {
if m.matcher(context.TODO(), clientHelloServerName(br)) {
if m.matcher(context.TODO(), ClientHelloServerName(br)) {
return m.target
}
return nil
Expand All @@ -88,7 +87,7 @@ type acmeMatch struct {
}

func (m *acmeMatch) match(br *bufio.Reader) Target {
sni := clientHelloServerName(br)
sni := ClientHelloServerName(br)
if !strings.HasSuffix(sni, ".acme.invalid") {
return nil
}
Expand Down Expand Up @@ -153,10 +152,10 @@ func tryACME(ctx context.Context, ch chan<- Target, dest Target, sni string) {
ret = dest
}

// clientHelloServerName returns the SNI server name inside the TLS ClientHello,
// ClientHelloServerName returns the SNI server name inside the TLS ClientHello,
// without consuming any bytes from br.
// On any error, the empty string is returned.
func clientHelloServerName(br *bufio.Reader) (sni string) {
func ClientHelloServerName(br *bufio.Reader) (sni string) {
const recordHeaderLen = 5
hdr, err := br.Peek(recordHeaderLen)
if err != nil {
Expand Down
85 changes: 67 additions & 18 deletions tcpproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ import (
"io"
"log"
"net"
"sync"
"time"
)

Expand All @@ -79,6 +80,9 @@ type Proxy struct {
// function. If nil, net.Dial is used.
// The provided net is always "tcp".
ListenFunc func(net, laddr string) (net.Listener, error)

// defaultHandler handles unmatched traffic
defaultHandler Target
}

// Matcher reports whether hostname matches the Matcher's criteria.
Expand All @@ -93,7 +97,9 @@ func equals(want string) Matcher {

// config contains the proxying state for one listener.
type config struct {
routes []route
routes *sync.Map // map[int]route
nextRouteID int

acmeTargets []Target // accumulates targets that should be probed for acme.
stopACME bool // if true, AddSNIRoute doesn't add targets to acmeTargets.
}
Expand Down Expand Up @@ -122,25 +128,58 @@ func (p *Proxy) configFor(ipPort string) *config {
p.configs = make(map[string]*config)
}
if p.configs[ipPort] == nil {
p.configs[ipPort] = &config{}
cfg := &config{}
cfg.routes = &sync.Map{}
cfg.nextRouteID = 1
p.configs[ipPort] = cfg
}
return p.configs[ipPort]
}

func (p *Proxy) addRoute(ipPort string, r route) {
cfg := p.configFor(ipPort)
cfg.routes = append(cfg.routes, r)
func (p *Proxy) addRoute(ipPort string, r route) (routeID int) {
var cfg *config
if p.donec != nil {
// NOTE: Do not create config file if the server is listening.
// This saves the handling of bringing up and tearing down
// listeners when add or remove route.
cfg = p.configs[ipPort]
} else {
cfg = p.configFor(ipPort)
}
if cfg != nil {
routeID = cfg.nextRouteID
cfg.nextRouteID++
cfg.routes.Store(routeID, r)
}
return
}

// SetDefaultHandler sets the default handler for proxy.
func (p *Proxy) SetDefaultHandler(t Target) {
p.defaultHandler = t
}

// AddRoute appends an always-matching route to the ipPort listener,
// directing any connection to dest.
// directing any connection to dest. The added route's id is returned
// for future removal. If routeID is zero, the route is not registered.
//
// This is generally used as either the only rule (for simple TCP
// proxies), or as the final fallback rule for an ipPort.
//
// The ipPort is any valid net.Listen TCP address.
func (p *Proxy) AddRoute(ipPort string, dest Target) {
p.addRoute(ipPort, fixedTarget{dest})
func (p *Proxy) AddRoute(ipPort string, dest Target) (routeID int) {
return p.addRoute(ipPort, fixedTarget{dest})
}

// RemoveRoute removes an existing route for ipPort. If the route is
// not found, this is an no-op.
//
// Both AddRoute and RemoveRoute is go-routine safe.
func (p *Proxy) RemoveRoute(ipPort string, routeID int) {
cfg := p.configs[ipPort]
if cfg != nil {
cfg.routes.Delete(routeID)
}
}

type fixedTarget struct {
Expand Down Expand Up @@ -197,7 +236,7 @@ func (p *Proxy) Start() error {
return err
}
p.lns = append(p.lns, ln)
go p.serveListener(errc, ln, config.routes)
go p.serveListener(errc, ln, config)
}
go p.awaitFirstError(errc)
return nil
Expand All @@ -208,22 +247,24 @@ func (p *Proxy) awaitFirstError(errc <-chan error) {
close(p.donec)
}

func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, routes []route) {
func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, cfg *config) {
for {
c, err := ln.Accept()
if err != nil {
ret <- err
return
}
go p.serveConn(c, routes)
go p.serveConn(c, cfg)
}
}

// serveConn runs in its own goroutine and matches c against routes.
// It returns whether it matched purely for testing.
func (p *Proxy) serveConn(c net.Conn, routes []route) bool {
func (p *Proxy) serveConn(c net.Conn, cfg *config) bool {
br := bufio.NewReader(c)
for _, route := range routes {
var handled bool
cfg.routes.Range(func(k, v interface{}) bool {
route := v.(route)
if target := route.match(br); target != nil {
if n := br.Buffered(); n > 0 {
peeked, _ := br.Peek(br.Buffered())
Expand All @@ -233,13 +274,21 @@ func (p *Proxy) serveConn(c net.Conn, routes []route) bool {
}
}
target.HandleConn(c)
return true
handled = true
return false // exit the iteration
}
return true
})
if !handled {
if p.defaultHandler != nil {
p.defaultHandler.HandleConn(c)
} else {
// TODO: hook for this?
log.Printf("tcpproxy: no routes matched conn %v/%v; closing", c.RemoteAddr().String(), c.LocalAddr().String())
c.Close()
}
}
// TODO: hook for this?
log.Printf("tcpproxy: no routes matched conn %v/%v; closing", c.RemoteAddr().String(), c.LocalAddr().String())
c.Close()
return false
return handled
}

// Conn is an incoming connection that has had some bytes read from it
Expand Down
Loading