From 98dcba4f26963716e04e7ceeee2d5941c3dfa7a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antoine=20Bach=C3=A9?= Date: Sat, 26 Jun 2021 12:05:39 +0200 Subject: [PATCH] Session's contexts can now be updated This is needed for DTLS restart, as mentioned in pion/webrtc#1636. --- README.md | 1 + session.go | 32 ++++++++++++++++++++++++++++---- session_srtcp.go | 4 +++- session_srtp.go | 4 +++- 4 files changed, 35 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 19551e3..bcc5afd 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,7 @@ Check out the **[contributing wiki](https://github.com/pion/webrtc/wiki/Contribu * [Mission Liao](https://github.com/mission-liao) * [Orlando](https://github.com/OrlandoCo) * [Tarrence van As](https://github.com/tarrencev) +* [Antoine Baché](https://github.com/Antonito) ### License MIT License - see [LICENSE](LICENSE) for full text diff --git a/session.go b/session.go index 95520f7..bb01f26 100644 --- a/session.go +++ b/session.go @@ -4,6 +4,7 @@ import ( "io" "net" "sync" + "sync/atomic" "github.com/pion/logging" "github.com/pion/transport/packetio" @@ -17,7 +18,8 @@ type streamSession interface { type session struct { localContextMutex sync.Mutex - localContext, remoteContext *Context + localContext *Context + remoteContext atomic.Value // *Context localOptions, remoteOptions []ContextOption newStream chan readStream @@ -106,17 +108,19 @@ func (s *session) close() error { } func (s *session) start(localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt []byte, profile ProtectionProfile, child streamSession) error { - var err error - s.localContext, err = CreateContext(localMasterKey, localMasterSalt, profile, s.localOptions...) + localContext, err := CreateContext(localMasterKey, localMasterSalt, profile, s.localOptions...) if err != nil { return err } - s.remoteContext, err = CreateContext(remoteMasterKey, remoteMasterSalt, profile, s.remoteOptions...) + remoteContext, err := CreateContext(remoteMasterKey, remoteMasterSalt, profile, s.remoteOptions...) if err != nil { return err } + s.localContext = localContext + s.remoteContext.Store(remoteContext) + go func() { defer func() { close(s.newStream) @@ -148,3 +152,23 @@ func (s *session) start(localMasterKey, localMasterSalt, remoteMasterKey, remote return nil } + +// UpdateContext updates the local and remote context of the session. +func (s *session) UpdateContext(config *Config) error { + localContext, err := CreateContext(config.Keys.LocalMasterKey, config.Keys.LocalMasterSalt, config.Profile, s.localOptions...) + if err != nil { + return err + } + remoteContext, err := CreateContext(config.Keys.RemoteMasterKey, config.Keys.RemoteMasterSalt, config.Profile, s.remoteOptions...) + if err != nil { + return err + } + + s.localContextMutex.Lock() + s.localContext = localContext + s.localContextMutex.Unlock() + + s.remoteContext.Store(remoteContext) + + return nil +} diff --git a/session_srtcp.go b/session_srtcp.go index a5fb656..2e6e168 100644 --- a/session_srtcp.go +++ b/session_srtcp.go @@ -147,7 +147,9 @@ func destinationSSRC(pkts []rtcp.Packet) []uint32 { } func (s *SessionSRTCP) decrypt(buf []byte) error { - decrypted, err := s.remoteContext.DecryptRTCP(buf, buf, nil) + // Safe since remoteContext always contains a *Context. + remoteContext := s.remoteContext.Load().(*Context) + decrypted, err := remoteContext.DecryptRTCP(buf, buf, nil) if err != nil { return err } diff --git a/session_srtp.go b/session_srtp.go index dc815af..99b255c 100644 --- a/session_srtp.go +++ b/session_srtp.go @@ -157,7 +157,9 @@ func (s *SessionSRTP) decrypt(buf []byte) error { return errFailedTypeAssertion } - decrypted, err := s.remoteContext.decryptRTP(buf, buf, h) + // Safe since remoteContext always contains a *Context. + remoteContext := s.remoteContext.Load().(*Context) + decrypted, err := remoteContext.decryptRTP(buf, buf, h) if err != nil { return err }