Skip to content

Commit

Permalink
feat: add WebSocket.NetConn
Browse files Browse the repository at this point in the history
  • Loading branch information
aofei committed Jan 22, 2021
1 parent f7ca0b3 commit 11dcec7
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 0 deletions.
9 changes: 9 additions & 0 deletions websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package air

import (
"io/ioutil"
"net"
"time"

"github.com/gorilla/websocket"
Expand Down Expand Up @@ -40,6 +41,14 @@ type WebSocket struct {
listened bool
}

// NetConn returns the underlying `net.Conn` of the ws.
//
// ATTENTION: You should never call this method unless you know what you are
// doing.
func (ws *WebSocket) NetConn() net.Conn {
return ws.conn.UnderlyingConn()
}

// SetMaxMessageBytes sets the maximum number of bytes allowed for the ws to
// read messages from the remote peer. If a message exceeds the limit, the ws
// sends a close message to the remote peer.
Expand Down
44 changes: 44 additions & 0 deletions websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"errors"
"fmt"
"io"
"net"
"testing"
"time"
Expand All @@ -12,6 +13,49 @@ import (
"github.com/stretchr/testify/assert"
)

func TestWebSocketNetConn(t *testing.T) {
a := New()
a.Address = "localhost:0"

buf := bytes.Buffer{}
a.GET("/", func(req *Request, res *Response) error {
ws, err := res.WebSocket()
if err != nil {
return err
}

if _, err := io.Copy(&buf, ws.NetConn()); err != nil {
return err
}

return ws.Close()
})

hijackOSStdout()

go a.Serve()
defer a.Close()
time.Sleep(100 * time.Millisecond)

revertOSStdout()

conn, _, err := websocket.DefaultDialer.Dial(
"ws://"+a.Addresses()[0],
nil,
)
assert.NoError(t, err)
assert.NotNil(t, conn)
defer conn.Close()

time.Sleep(100 * time.Millisecond)

n, err := conn.UnderlyingConn().Write([]byte("Foobar"))
assert.NoError(t, err)
assert.Equal(t, 6, n)
time.Sleep(100 * time.Millisecond)
assert.Equal(t, "Foobar", buf.String())
}

func TestWebSocketSetMaxMessageBytes(t *testing.T) {
a := New()
a.Address = "localhost:0"
Expand Down

0 comments on commit 11dcec7

Please sign in to comment.