diff --git a/webdav/lock.go b/webdav/lock.go index 344ac5cea..a1538d792 100644 --- a/webdav/lock.go +++ b/webdav/lock.go @@ -93,6 +93,19 @@ type LockSystem interface { Unlock(now time.Time, token string) error } +// LockDeleter extends a LockSystem to support deleting locks by file. This is +// useful when the LockSystem is decoupled from the FileSystem, particularly +// because some FileSystem operations impact the LockSystem (e.g. deleting a +// file). +type LockDeleter interface { + // Delete removes all locks in the system rooted at name. + // + // A lock on the name should be confirmed prior to calling Delete. + // If Delete returns any non-nil error the Handler will write a "409 + // Conflict" HTTP status + Delete(now time.Time, name string) error +} + // LockDetails are a lock's metadata. type LockDetails struct { // Root is the root resource name being locked. For a zero-depth lock, the @@ -288,6 +301,22 @@ func (m *memLS) Unlock(now time.Time, token string) error { return nil } +func (m *memLS) Delete(now time.Time, name string) error { + m.mu.Lock() + defer m.mu.Unlock() + m.collectExpiredNodes(now) + + n := m.byName[name] + if n == nil { + // No locks for this node. That's okay, since the goal of this + // function is to blindly clean things up. + return nil + } + + m.remove(n) + return nil +} + func (m *memLS) canCreate(name string, zeroDepth bool) bool { return walkToRoot(name, func(name0 string, first bool) bool { n := m.byName[name0] diff --git a/webdav/webdav.go b/webdav/webdav.go index d88995f9f..aa48d1290 100644 --- a/webdav/webdav.go +++ b/webdav/webdav.go @@ -169,6 +169,22 @@ func (h *Handler) confirmLocks(r *http.Request, src, dst string) (release func() return nil, http.StatusPreconditionFailed, ErrLocked } +func (h *Handler) deleteLocks(reqPath string) (status int, err error) { + deleter, ok := h.LockSystem.(LockDeleter) + if !ok { + // Can't delete -- system doesn't support it. Assume it can handle this case + return 0, nil + } + + err = deleter.Delete(time.Now(), reqPath) + if err != nil { + return http.StatusInternalServerError, err + } + + return 0, nil + +} + func (h *Handler) handleOptions(w http.ResponseWriter, r *http.Request) (status int, err error) { reqPath, status, err := h.stripPrefix(r.URL.Path) if err != nil { @@ -247,6 +263,10 @@ func (h *Handler) handleDelete(w http.ResponseWriter, r *http.Request) (status i if err := h.FileSystem.RemoveAll(ctx, reqPath); err != nil { return http.StatusMethodNotAllowed, err } + + if status, err := h.deleteLocks(reqPath); err != nil { + return status, err + } return http.StatusNoContent, nil } @@ -386,7 +406,17 @@ func (h *Handler) handleCopyMove(w http.ResponseWriter, r *http.Request) (status return http.StatusBadRequest, errInvalidDepth } } - return moveFiles(ctx, h.FileSystem, src, dst, r.Header.Get("Overwrite") == "T") + status, err = moveFiles(ctx, h.FileSystem, src, dst, r.Header.Get("Overwrite") == "T") + if status < 200 || status > 300 { + return status, err + } + + delStatus, err := h.deleteLocks(src) + if err != nil { + return delStatus, err + } + + return status, err } func (h *Handler) handleLock(w http.ResponseWriter, r *http.Request) (retStatus int, retErr error) { diff --git a/webdav/webdav_test.go b/webdav/webdav_test.go index 2baebe3c9..f5b15d3a8 100644 --- a/webdav/webdav_test.go +++ b/webdav/webdav_test.go @@ -21,21 +21,21 @@ import ( "testing" ) +// createLockBody comes from the example in Section 9.10.7. +const createLockBody = ` + + + + + http://example.org/~ejw/contact.html + + +` + // TODO: add tests to check XML responses with the expected prefix path func TestPrefix(t *testing.T) { const dst, blah = "Destination", "blah blah blah" - // createLockBody comes from the example in Section 9.10.7. - const createLockBody = ` - - - - - http://example.org/~ejw/contact.html - - - ` - do := func(method, urlStr string, body string, wantStatusCode int, headers ...string) (http.Header, error) { var bodyReader io.Reader if body != "" { @@ -347,3 +347,112 @@ func TestFilenameEscape(t *testing.T) { } } } + +func TestDelete(t *testing.T) { + do := func(method, urlStr, body string, expectedCode int, headers ...string) (error, http.Header) { + var bodyReader io.Reader + if body != "" { + bodyReader = strings.NewReader(body) + } + req, err := http.NewRequest(method, urlStr, bodyReader) + if err != nil { + return err, nil + } + + for len(headers) >= 2 { + req.Header.Add(headers[0], headers[1]) + headers = headers[2:] + } + + res, err := http.DefaultClient.Do(req) + if err != nil { + return err, nil + } + defer res.Body.Close() + + _, err = ioutil.ReadAll(res.Body) + if err != nil { + return err, res.Header + } + if res.StatusCode != expectedCode { + return fmt.Errorf("%q path=%q: got status %d, want %d", method, urlStr, res.StatusCode, expectedCode), nil + + } + return nil, res.Header + } + + testCases := []struct { + path string + lock, recreate bool + }{{ + path: `/file`, + }, { + path: `/file`, + recreate: true, + }, { + // This reproduces https://github.com/golang/go/issues/42839 + path: `/something_else`, + lock: true, + recreate: true, + }} + + srv := httptest.NewServer(&Handler{ + FileSystem: NewMemFS(), + LockSystem: NewMemLS(), + }) + defer srv.Close() + + u, err := url.Parse(srv.URL) + if err != nil { + t.Fatal(err) + } + // Runs through the following logic: + // Create File + // If locking, lock file + // Delete file + // Check that file is gone + // If recreating, recreate file + for _, tc := range testCases { + u.Path = tc.path + err, _ := do("PUT", u.String(), "content", http.StatusCreated) + if err != nil { + t.Errorf("Initial create: %v", err) + continue + } + + headers := []string{} + if tc.lock { + err, hdrs := do("LOCK", u.String(), createLockBody, http.StatusOK) + if err != nil { + t.Errorf("Lock: %v", err) + continue + } + + lockToken := hdrs.Get("Lock-Token") + if lockToken != "" { + ifHeader := fmt.Sprintf("<%s%s> (%s)", srv.URL, tc.path, lockToken) + headers = append(headers, "If", ifHeader) + } + } + + err, _ = do("DELETE", u.String(), "", http.StatusNoContent, headers...) + if err != nil { + t.Errorf("Delete: %v", err) + continue + } + + err, _ = do("GET", u.String(), "", http.StatusNotFound) + if err != nil { + t.Errorf("Get: %v", err) + continue + } + + if tc.recreate { + err, _ := do("PUT", u.String(), "content", http.StatusCreated) + if err != nil { + t.Errorf("Second create: %v", err) + continue + } + } + } +}