diff --git a/webdav/lock.go b/webdav/lock.go
index 344ac5cea..a1538d792 100644
--- a/webdav/lock.go
+++ b/webdav/lock.go
@@ -93,6 +93,19 @@ type LockSystem interface {
Unlock(now time.Time, token string) error
}
+// LockDeleter extends a LockSystem to support deleting locks by file. This is
+// useful when the LockSystem is decoupled from the FileSystem, particularly
+// because some FileSystem operations impact the LockSystem (e.g. deleting a
+// file).
+type LockDeleter interface {
+ // Delete removes all locks in the system rooted at name.
+ //
+ // A lock on the name should be confirmed prior to calling Delete.
+ // If Delete returns any non-nil error the Handler will write a "409
+ // Conflict" HTTP status
+ Delete(now time.Time, name string) error
+}
+
// LockDetails are a lock's metadata.
type LockDetails struct {
// Root is the root resource name being locked. For a zero-depth lock, the
@@ -288,6 +301,22 @@ func (m *memLS) Unlock(now time.Time, token string) error {
return nil
}
+func (m *memLS) Delete(now time.Time, name string) error {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.collectExpiredNodes(now)
+
+ n := m.byName[name]
+ if n == nil {
+ // No locks for this node. That's okay, since the goal of this
+ // function is to blindly clean things up.
+ return nil
+ }
+
+ m.remove(n)
+ return nil
+}
+
func (m *memLS) canCreate(name string, zeroDepth bool) bool {
return walkToRoot(name, func(name0 string, first bool) bool {
n := m.byName[name0]
diff --git a/webdav/webdav.go b/webdav/webdav.go
index d88995f9f..aa48d1290 100644
--- a/webdav/webdav.go
+++ b/webdav/webdav.go
@@ -169,6 +169,22 @@ func (h *Handler) confirmLocks(r *http.Request, src, dst string) (release func()
return nil, http.StatusPreconditionFailed, ErrLocked
}
+func (h *Handler) deleteLocks(reqPath string) (status int, err error) {
+ deleter, ok := h.LockSystem.(LockDeleter)
+ if !ok {
+ // Can't delete -- system doesn't support it. Assume it can handle this case
+ return 0, nil
+ }
+
+ err = deleter.Delete(time.Now(), reqPath)
+ if err != nil {
+ return http.StatusInternalServerError, err
+ }
+
+ return 0, nil
+
+}
+
func (h *Handler) handleOptions(w http.ResponseWriter, r *http.Request) (status int, err error) {
reqPath, status, err := h.stripPrefix(r.URL.Path)
if err != nil {
@@ -247,6 +263,10 @@ func (h *Handler) handleDelete(w http.ResponseWriter, r *http.Request) (status i
if err := h.FileSystem.RemoveAll(ctx, reqPath); err != nil {
return http.StatusMethodNotAllowed, err
}
+
+ if status, err := h.deleteLocks(reqPath); err != nil {
+ return status, err
+ }
return http.StatusNoContent, nil
}
@@ -386,7 +406,17 @@ func (h *Handler) handleCopyMove(w http.ResponseWriter, r *http.Request) (status
return http.StatusBadRequest, errInvalidDepth
}
}
- return moveFiles(ctx, h.FileSystem, src, dst, r.Header.Get("Overwrite") == "T")
+ status, err = moveFiles(ctx, h.FileSystem, src, dst, r.Header.Get("Overwrite") == "T")
+ if status < 200 || status > 300 {
+ return status, err
+ }
+
+ delStatus, err := h.deleteLocks(src)
+ if err != nil {
+ return delStatus, err
+ }
+
+ return status, err
}
func (h *Handler) handleLock(w http.ResponseWriter, r *http.Request) (retStatus int, retErr error) {
diff --git a/webdav/webdav_test.go b/webdav/webdav_test.go
index 2baebe3c9..f5b15d3a8 100644
--- a/webdav/webdav_test.go
+++ b/webdav/webdav_test.go
@@ -21,21 +21,21 @@ import (
"testing"
)
+// createLockBody comes from the example in Section 9.10.7.
+const createLockBody = `
+
+
+
+
+ http://example.org/~ejw/contact.html
+
+
+`
+
// TODO: add tests to check XML responses with the expected prefix path
func TestPrefix(t *testing.T) {
const dst, blah = "Destination", "blah blah blah"
- // createLockBody comes from the example in Section 9.10.7.
- const createLockBody = `
-
-
-
-
- http://example.org/~ejw/contact.html
-
-
- `
-
do := func(method, urlStr string, body string, wantStatusCode int, headers ...string) (http.Header, error) {
var bodyReader io.Reader
if body != "" {
@@ -347,3 +347,112 @@ func TestFilenameEscape(t *testing.T) {
}
}
}
+
+func TestDelete(t *testing.T) {
+ do := func(method, urlStr, body string, expectedCode int, headers ...string) (error, http.Header) {
+ var bodyReader io.Reader
+ if body != "" {
+ bodyReader = strings.NewReader(body)
+ }
+ req, err := http.NewRequest(method, urlStr, bodyReader)
+ if err != nil {
+ return err, nil
+ }
+
+ for len(headers) >= 2 {
+ req.Header.Add(headers[0], headers[1])
+ headers = headers[2:]
+ }
+
+ res, err := http.DefaultClient.Do(req)
+ if err != nil {
+ return err, nil
+ }
+ defer res.Body.Close()
+
+ _, err = ioutil.ReadAll(res.Body)
+ if err != nil {
+ return err, res.Header
+ }
+ if res.StatusCode != expectedCode {
+ return fmt.Errorf("%q path=%q: got status %d, want %d", method, urlStr, res.StatusCode, expectedCode), nil
+
+ }
+ return nil, res.Header
+ }
+
+ testCases := []struct {
+ path string
+ lock, recreate bool
+ }{{
+ path: `/file`,
+ }, {
+ path: `/file`,
+ recreate: true,
+ }, {
+ // This reproduces https://github.com/golang/go/issues/42839
+ path: `/something_else`,
+ lock: true,
+ recreate: true,
+ }}
+
+ srv := httptest.NewServer(&Handler{
+ FileSystem: NewMemFS(),
+ LockSystem: NewMemLS(),
+ })
+ defer srv.Close()
+
+ u, err := url.Parse(srv.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Runs through the following logic:
+ // Create File
+ // If locking, lock file
+ // Delete file
+ // Check that file is gone
+ // If recreating, recreate file
+ for _, tc := range testCases {
+ u.Path = tc.path
+ err, _ := do("PUT", u.String(), "content", http.StatusCreated)
+ if err != nil {
+ t.Errorf("Initial create: %v", err)
+ continue
+ }
+
+ headers := []string{}
+ if tc.lock {
+ err, hdrs := do("LOCK", u.String(), createLockBody, http.StatusOK)
+ if err != nil {
+ t.Errorf("Lock: %v", err)
+ continue
+ }
+
+ lockToken := hdrs.Get("Lock-Token")
+ if lockToken != "" {
+ ifHeader := fmt.Sprintf("<%s%s> (%s)", srv.URL, tc.path, lockToken)
+ headers = append(headers, "If", ifHeader)
+ }
+ }
+
+ err, _ = do("DELETE", u.String(), "", http.StatusNoContent, headers...)
+ if err != nil {
+ t.Errorf("Delete: %v", err)
+ continue
+ }
+
+ err, _ = do("GET", u.String(), "", http.StatusNotFound)
+ if err != nil {
+ t.Errorf("Get: %v", err)
+ continue
+ }
+
+ if tc.recreate {
+ err, _ := do("PUT", u.String(), "content", http.StatusCreated)
+ if err != nil {
+ t.Errorf("Second create: %v", err)
+ continue
+ }
+ }
+ }
+}