Skip to content

Commit

Permalink
Add DAITA support
Browse files Browse the repository at this point in the history
Add support for DAITA using `maybenot`.
---------

Co-authored-by: Sebastian Holmin <[email protected]>
Co-authored-by: Joakim Hulthe <[email protected]>
Co-authored-by: David Lönnhager <[email protected]>
Co-authored-by: Markus Pettersson <[email protected]>
  • Loading branch information
4 people committed Jun 25, 2024
1 parent 2163620 commit f4bc3ae
Show file tree
Hide file tree
Showing 11 changed files with 460 additions and 15 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
wireguard-go
libmaybenot.a
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "maybenot"]
path = maybenot
url = https://github.com/mullvad/maybenot
9 changes: 9 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
PREFIX ?= /usr
DESTDIR ?=
LIBDEST ?= $(CURDIR)
BINDIR ?= $(PREFIX)/bin
TARGET ?=
export GO111MODULE := on

all: generate-version-and-build
Expand All @@ -22,10 +24,17 @@ wireguard-go: $(wildcard *.go) $(wildcard */*.go)
install: wireguard-go
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go"

daita: libmaybenot.a
go build --tags daita -v -o wireguard-go

libmaybenot.a: $(wildcard maybenot/*)
make --directory maybenot/crates/maybenot-ffi/ DESTINATION=$(LIBDEST) TARGET=$(TARGET)

test:
go test ./...

clean:
rm -f wireguard-go
rm -f libmaybenot.a

.PHONY: all clean test install generate-version-and-build
18 changes: 16 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Go Implementation of [WireGuard](https://www.wireguard.com/)
# Go Implementation of [WireGuard](https://www.wireguard.com/) - Mullvad VPN fork

This is an implementation of WireGuard in Go.
This is an implementation of WireGuard in Go with support for [DAITA](#daita).

## Usage

Expand Down Expand Up @@ -54,6 +54,20 @@ $ cd wireguard-go
$ make
```

### DAITA

[DAITA](https://mullvad.net/en/blog/introducing-defense-against-ai-guided-traffic-analysis-daita) is a Mullvad-specific addition to wireguard-go which integrates the [maybenot](https://github.com/maybenot-io/maybenot) framework for traffic analysis defenses. To build wireguard-go with DAITA you need to initialize the `maybenot` submodule.

```
git submodule update --init
```

Then build `wireguard-go` with DAITA support

```
make daita
```

## License

Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
Expand Down
296 changes: 296 additions & 0 deletions device/daita.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
//go:build daita
// +build daita

package device

import (
"encoding/binary"
"sync"
"time"
"unsafe"
)

// #include <stdio.h>
// #include <stdlib.h>
// #include "../maybenot/crates/maybenot-ffi/maybenot.h"
// #cgo LDFLAGS: -L${SRCDIR}/../ -lmaybenot -lm
import "C"

type MaybenotDaita struct {
events chan Event
actions chan Action
maybenot *C.MaybenotFramework
newActionsBuf []C.MaybenotAction
paddingQueue map[uint64]*time.Timer // Map from machine to queued padding packets
logger *Logger
stopping sync.WaitGroup // waitgroup for handleEvents and HandleDaitaActions
}

type Event struct {
// The machine that generated the action that generated this event, if any.
Machine uint64

Peer NoisePublicKey
EventType EventType
XmitBytes uint16
}

type ActionType uint32

const (
ActionTypeCancel ActionType = iota
ActionTypeInjectPadding
ActionTypeBlockOutgoing
)

const (
ERROR_GENERAL_FAILURE = -1
ERROR_INTERMITTENT_FAILURE = -2
)

type Action struct {
ActionType ActionType

// The maybenot machine that generated the action.
// Should be propagated back by events generated by this action.
Machine uint64

// The time at which the action should be performed
Timeout time.Duration

// Information about the padding action
Payload Padding
}

type Padding struct {
// The size of the padding packet, in bytes. NOT including the Daita header.
ByteCount uint16
Replace bool
}

func (peer *Peer) EnableDaita(machines string, eventsCapacity uint, actionsCapacity uint, maxPaddingBytes float64, maxBlockingBytes float64) bool {
peer.Lock()
defer peer.Unlock()

if !peer.isRunning.Load() {
return false
}

if peer.daita != nil {
peer.device.log.Errorf("Failed to activate DAITA as it is already active")
return false
}

peer.device.log.Verbosef("Enabling DAITA for peer: %v", peer)

mtu := peer.device.tun.mtu.Load()

peer.device.log.Verbosef("MTU %v", mtu)
var maybenot *C.MaybenotFramework
c_machines := C.CString(machines)

c_maxPaddingBytes := C.double(maxPaddingBytes)
c_maxBlockingBytes := C.double(maxBlockingBytes)

maybenot_result := C.maybenot_start(
c_machines, c_maxPaddingBytes, c_maxBlockingBytes, C.ushort(mtu),
&maybenot,
)
C.free(unsafe.Pointer(c_machines))

if maybenot_result != 0 {
peer.device.log.Errorf("Failed to initialize maybenot, code=%d", maybenot_result)
return false
}

numMachines := C.maybenot_num_machines(maybenot)
daita := MaybenotDaita{
events: make(chan Event, eventsCapacity),
maybenot: maybenot,
newActionsBuf: make([]C.MaybenotAction, numMachines),
paddingQueue: map[uint64]*time.Timer{},
logger: peer.device.log,
}

daita.stopping.Add(1)
go daita.handleEvents(peer)
peer.daita = &daita

return true
}

// Stop the MaybenotDaita instance. It must not be used after calling this.
func (daita *MaybenotDaita) Close() {
daita.logger.Verbosef("Waiting for DAITA routines to stop")
close(daita.events)
for _, queuedPadding := range daita.paddingQueue {
if queuedPadding.Stop() {
daita.stopping.Done()
}
}
daita.stopping.Wait()
daita.logger.Verbosef("DAITA routines have stopped")
}

func (daita *MaybenotDaita) NonpaddingReceived(peer *Peer, packetLen uint) {
daita.event(peer, NonpaddingReceived, packetLen, 0)
}

func (daita *MaybenotDaita) PaddingReceived(peer *Peer, packetLen uint) {
daita.event(peer, PaddingReceived, packetLen, 0)
}

func (daita *MaybenotDaita) PaddingSent(peer *Peer, packetLen uint, machine uint64) {
daita.event(peer, PaddingSent, packetLen, machine)
}

func (daita *MaybenotDaita) NonpaddingSent(peer *Peer, packetLen uint) {
daita.event(peer, NonpaddingSent, packetLen, 0)
}

func (daita *MaybenotDaita) event(peer *Peer, eventType EventType, packetLen uint, machine uint64) {
if daita == nil {
return
}

event := Event{
Machine: machine,
Peer: peer.handshake.remoteStatic,
EventType: eventType,
XmitBytes: uint16(packetLen),
}

select {
case daita.events <- event:
default:
peer.device.log.Verbosef("Dropped DAITA event %v due to full buffer", event.EventType)
}
}

func injectPadding(action Action, peer *Peer) {
if action.ActionType != ActionTypeInjectPadding {
peer.device.log.Errorf("Got unknown action type %v", action.ActionType)
return
}

elem := peer.device.NewOutboundElement()

size := action.Payload.ByteCount
if size < DaitaHeaderLen || size > uint16(peer.device.tun.mtu.Load()) {
peer.device.log.Errorf("DAITA padding action contained invalid size %v bytes", size)
return
}

elem.packet = elem.buffer[MessageTransportHeaderSize : MessageTransportHeaderSize+int(size)]
elem.packet[0] = DaitaPaddingMarker
binary.BigEndian.PutUint16(elem.packet[DaitaOffsetTotalLength:DaitaOffsetTotalLength+2], size)

if peer.isRunning.Load() {
peer.StagePacket(elem)
elem = nil
peer.SendStagedPackets()

peer.daita.PaddingSent(peer, uint(size), action.Machine)
}
}

func (daita *MaybenotDaita) handleEvents(peer *Peer) {
defer func() {
C.maybenot_stop(daita.maybenot)
daita.stopping.Done()
daita.logger.Verbosef("%v - DAITA: event handler - stopped", peer)
}()

for {
event, more := <-daita.events
if !more {
return
}

daita.handleEvent(event, peer)
}
}

func (daita *MaybenotDaita) handleEvent(event Event, peer *Peer) {

for _, cAction := range daita.maybenotEventToActions(event) {
action := cActionToGo(cAction)

switch action.ActionType {
case ActionTypeCancel:
machine := action.Machine
// If padding is queued for the machine, cancel it
if queuedPadding, ok := daita.paddingQueue[machine]; ok {
if queuedPadding.Stop() {
daita.stopping.Done()
}
}
case ActionTypeInjectPadding:
// Check if a padding packet was already queued for the machine
// If so, try to cancel it
timer, paddingWasQueued := daita.paddingQueue[action.Machine]
// If no padding was queued, or the action fire before we manage to
// cancel it, we need to increment the wait group again
if !paddingWasQueued || !timer.Stop() {
daita.stopping.Add(1)
}

daita.paddingQueue[action.Machine] =
time.AfterFunc(action.Timeout, func() {
defer daita.stopping.Done()
injectPadding(action, peer)
})
case ActionTypeBlockOutgoing:
daita.logger.Errorf("ignoring action type ActionTypeBlockOutgoing, unimplemented")
continue
}
}
}

func (daita *MaybenotDaita) maybenotEventToActions(event Event) []C.MaybenotAction {
cEvent := C.MaybenotEvent{
machine: C.uintptr_t(event.Machine),
event_type: C.uint32_t(event.EventType),
xmit_bytes: C.uint16_t(event.XmitBytes),
}

var actionsWritten C.uintptr_t

// TODO: use unsafe.SliceData instead of the pointer dereference when the Go version gets bumped to 1.20 or later
// TODO: fetch an error string from the FFI corresponding to the error code
result := C.maybenot_on_events(daita.maybenot, &cEvent, 1, &daita.newActionsBuf[0], &actionsWritten)
if result != 0 {
daita.logger.Errorf("Failed to handle event as it was a null pointer\nEvent: %d\n", event)
return nil
}

newActions := daita.newActionsBuf[:actionsWritten]
return newActions
}

func cActionToGo(action_c C.MaybenotAction) Action {
if action_c.tag != C.MaybenotAction_InjectPadding {
panic("Unsupported tag")
}

// cast union to the ActionInjectPadding variant
padding_action := (*C.MaybenotAction_InjectPadding_Body)(unsafe.Pointer(&action_c.anon0[0]))

timeout := maybenotDurationToGoDuration(padding_action.timeout)

return Action{
Machine: uint64(padding_action.machine),
Timeout: timeout,
ActionType: ActionTypeInjectPadding,
Payload: Padding{
ByteCount: uint16(padding_action.size),
Replace: bool(padding_action.replace),
},
}
}

func maybenotDurationToGoDuration(duration C.MaybenotDuration) time.Duration {
// let's just assume this is fine...
nanoseconds := uint64(duration.secs)*1_000_000_000 + uint64(duration.nanos)
return time.Duration(nanoseconds)
}
Loading

0 comments on commit f4bc3ae

Please sign in to comment.