Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

webdav: only require locks when necessary #93

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 103 additions & 30 deletions webdav/webdav.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,45 +96,116 @@ func (h *Handler) lock(now time.Time, root string) (token string, status int, er
return token, 0, nil
}

func (h *Handler) speculativeLock(now time.Time, root, ifHdr string) (token string, status int, err error) {
token, status, err = h.lock(now, root)

// If we succeed or fail for any reason other than ErrLocked, short-circuit out
if err != ErrLocked {
return
}

// If we have an If header, then while we may have failed to take an anonymous lock,
// we have an opportunity to confirm the lock using the presented credentials.
if ifHdr != "" {
err = nil
status = 0
}

return
}

func (h *Handler) confirmLocks(r *http.Request, src, dst string) (release func(), status int, err error) {
// The general strategy for lock confirmation is as follows:
// 1. Speculatively take locks for src/dst in case no locks are held on them.
// 2. If any locks *are* held on src/dst, check that the If header satisfies them.
// 3. As part of checking #2, validate any other constraints implied in the If header.

hdr := r.Header.Get("If")
if hdr == "" {
// An empty If header means that the client hasn't previously created locks.
// Even if this client doesn't care about locks, we still need to check that
// the resources aren't locked by another client, so we create temporary
// locks that would conflict with another client's locks. These temporary
// locks are unlocked at the end of the HTTP request.
now, srcToken, dstToken := time.Now(), "", ""
if src != "" {
srcToken, status, err = h.lock(now, src)
if err != nil {
return nil, status, err
}
ih := ifHeader{}
// Parse the If header first, since if it's bad we should just fail out, even if we
// don't need it.
if hdr != "" {
var ok bool
ih, ok = parseIfHeader(hdr)
if !ok {
return nil, http.StatusBadRequest, errInvalidIfHeader
}
if dst != "" {
dstToken, status, err = h.lock(now, dst)
if err != nil {
if srcToken != "" {
h.LockSystem.Unlock(now, srcToken)
}
return nil, status, err
}
}

// Even if the client hasn't previously created locks, another principle may have.
// So, we still need to check that the resources aren't locked by another client.
// To do so, we create temporary locks on the requested resources. If the resoure is
// locked, the operation will fail. In that case, we will check if the request
// presented a lock for that particular resource.

// Any temporary locks are removed at the end of the request.
now, srcToken, dstToken := time.Now(), "", ""
if src != "" {
srcToken, status, err = h.speculativeLock(now, src, hdr)
if err != nil {
return nil, status, err
}
}

return func() {
if dstToken != "" {
h.LockSystem.Unlock(now, dstToken)
}
if dst != "" {
dstToken, status, err = h.speculativeLock(now, dst, hdr)
if err != nil {
if srcToken != "" {
h.LockSystem.Unlock(now, srcToken)
}
}, 0, nil
return nil, status, err
}
}

speculativeLockRelease := func() {
if dstToken != "" {
h.LockSystem.Unlock(now, dstToken)
}
if srcToken != "" {
h.LockSystem.Unlock(now, srcToken)
}
}

// Exclude src/dst from the lock search if we already have a lock for them.
if srcToken != "" {
src = ""
}

if dstToken != "" {
dst = ""
}

if src == "" && dst == "" {

if len(ih.lists) == 0 {
// No conditions to evaluate and no src/dst constraints to check.
// Everything with the request is good. Return success.
return speculativeLockRelease, 0, nil
}

// In this case, we have created temporary locks on any resource we care about.
// This means that none of the conditions can evaluate to true. Fall through with
// an empty list to ensure consistent error handling
ih.lists = []ifList{}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a huge fan of this -- it feels a little messy to depend on the behaviour of atLeastOnIfListPasses, but I think the newly created functions name + the comment make it clear enough what's expected.

}

ih, ok := parseIfHeader(hdr)
if !ok {
return nil, http.StatusBadRequest, errInvalidIfHeader
lockRelease, status, err := h.atLeastOneIfListPasses(src, dst, r.Host, ih)

if err != nil {
speculativeLockRelease()
return nil, status, err
}

return func() {
// Release both the locks we just confirmed, and any we speculatively created.
lockRelease()
speculativeLockRelease()
}, 0, nil
}

func (h *Handler) atLeastOneIfListPasses(src, dst string, host string, ih ifHeader) (release func(), status int, err error) {

// Run the list of provided lock tokens agains the resources we want to lock.
// ih is a disjunction (OR) of ifLists, so any ifList will do.
for _, l := range ih.lists {
lsrc := l.resourceTag
Expand All @@ -145,23 +216,25 @@ func (h *Handler) confirmLocks(r *http.Request, src, dst string) (release func()
if err != nil {
continue
}
if u.Host != r.Host {
if u.Host != host {
continue
}
lsrc, status, err = h.stripPrefix(u.Path)
if err != nil {
return nil, status, err
}
}
release, err = h.LockSystem.Confirm(time.Now(), lsrc, dst, l.conditions...)
release, err := h.LockSystem.Confirm(time.Now(), lsrc, dst, l.conditions...)
if err == ErrConfirmationFailed {
continue
}
if err != nil {
return nil, http.StatusInternalServerError, err
}

return release, 0, nil
}

// Section 10.4.1 says that "If this header is evaluated and all state lists
// fail, then the request must fail with a 412 (Precondition Failed) status."
// We follow the spec even though the cond_put_corrupt_token test case from
Expand Down
96 changes: 90 additions & 6 deletions webdav/webdav_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,8 @@ import (
"testing"
)

// TODO: add tests to check XML responses with the expected prefix path
func TestPrefix(t *testing.T) {
const dst, blah = "Destination", "blah blah blah"

// createLockBody comes from the example in Section 9.10.7.
const createLockBody = `<?xml version="1.0" encoding="utf-8" ?>
// createLockBody comes from the example in Section 9.10.7.
const createLockBody = `<?xml version="1.0" encoding="utf-8" ?>
<D:lockinfo xmlns:D='DAV:'>
<D:lockscope><D:exclusive/></D:lockscope>
<D:locktype><D:write/></D:locktype>
Expand All @@ -36,6 +32,10 @@ func TestPrefix(t *testing.T) {
</D:lockinfo>
`

// TODO: add tests to check XML responses with the expected prefix path
func TestPrefix(t *testing.T) {
const dst, blah = "Destination", "blah blah blah"

do := func(method, urlStr string, body string, wantStatusCode int, headers ...string) (http.Header, error) {
var bodyReader io.Reader
if body != "" {
Expand Down Expand Up @@ -185,6 +185,29 @@ func TestPrefix(t *testing.T) {
continue
}

wantI := map[string]int{
"/": http.StatusLocked,
"/a/": http.StatusLocked,
"/a/b/": http.StatusLocked,
"/a/b/c/": http.StatusNotFound,
}[prefix]
if _, err := do("PUT", srv.URL+"/a/b/e/g", blah, wantI); err != nil {
t.Errorf("prefix=%-9q PUT /a/b/e/g: %v", prefix, err)
continue
}

badIfHeader := fmt.Sprintf("<%s/a/b/e/g> (foobar)", srv.URL)
wantJ := map[string]int{
"/": http.StatusPreconditionFailed,
"/a/": http.StatusPreconditionFailed,
"/a/b/": http.StatusPreconditionFailed,
"/a/b/c/": http.StatusNotFound,
}[prefix]
if _, err := do("PUT", srv.URL+"/a/b/e/g", blah, wantJ, "If", badIfHeader); err != nil {
t.Errorf("prefix=%-9q PUT /a/b/e/g: %v", prefix, err)
continue
}

got, err := find(ctx, nil, fs, "/")
if err != nil {
t.Errorf("prefix=%-9q find: %v", prefix, err)
Expand Down Expand Up @@ -347,3 +370,64 @@ func TestFilenameEscape(t *testing.T) {
}
}
}

func TestMoveLockedSrcUnlockedDst(t *testing.T) {
// This test reproduces https://github.com/golang/go/issues/43556
do := func(method, urlStr string, body string, wantStatusCode int, headers ...string) (http.Header, error) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these tests could do with some refactoring. I tend to favour writing simple test harnesses that handle much of the common behaviour between tests. In this case, I think we could refactor the 'do' function as well as the server setup. That would simplify both the existing tests, and any new ones which use a similar pattern. I was going to do it in this PR, but it probably makes more sense to do it in a separate one specific to that refactoring. Let me know if you're interested.

var bodyReader io.Reader
if body != "" {
bodyReader = strings.NewReader(body)
}
req, err := http.NewRequest(method, urlStr, bodyReader)
if err != nil {
return nil, err
}
for len(headers) >= 2 {
req.Header.Add(headers[0], headers[1])
headers = headers[2:]
}
res, err := http.DefaultTransport.RoundTrip(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != wantStatusCode {
return nil, fmt.Errorf("got status code %d, want %d", res.StatusCode, wantStatusCode)
}
return res.Header, nil
}

srv := httptest.NewServer(&Handler{
FileSystem: NewMemFS(),
LockSystem: NewMemLS(),
})
defer srv.Close()

src, err := url.Parse(srv.URL)
if err != nil {
t.Fatal(err)
}

src.Path = "/locked_src"
hdrs, err := do("LOCK", src.String(), createLockBody, http.StatusCreated)
if err != nil {
t.Fatal(err)
}

lockToken := hdrs.Get("Lock-Token")
if lockToken == "" {
t.Errorf("Expected lock token")
}

if err != nil {
t.Fatal(err)
}
ifHeader := fmt.Sprintf("<%s%s> (%s)", srv.URL, src.Path, lockToken)
headers := []string{"If", ifHeader, "Destination", "/unlocked_path"}

_, err = do("MOVE", src.String(), "", http.StatusCreated, headers...)

if err != nil {
t.Fatal(err)
}
}