From 2666b09abedcbcd9184f47e171da2919a4651f6b Mon Sep 17 00:00:00 2001
From: Jonas Heinle <jonas.heinle@ipa.fraunhofer.de>
Date: Wed, 17 Jul 2024 21:35:52 +0200
Subject: [PATCH] fix

---
 .gitignore                   |  2 ++
 tests/test_webdav_client.py  | 34 ++++++++++++++++++----------------
 webdavclient/webdavclient.py | 16 +++++++++++-----
 3 files changed, 31 insertions(+), 21 deletions(-)

diff --git a/.gitignore b/.gitignore
index efa407c..7fdc6fc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+local_data
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/tests/test_webdav_client.py b/tests/test_webdav_client.py
index bcb834e..dbb8abf 100644
--- a/tests/test_webdav_client.py
+++ b/tests/test_webdav_client.py
@@ -9,7 +9,7 @@
 
 
 def cleanup_test_files() -> None:
-    pass
+    os.removedirs("local_data")
 
 
 def wait_for_server_to_start(url, timeout=10) -> bool:
@@ -62,10 +62,21 @@ def webdav_client() -> WebDavClient:
     return WebDavClient(hostname, username, password)
 
 
+def test_filter_after_global_base_path(webdav_client) -> None:
+    path = "/data/subfolder1/text.txt"
+    path2 = "http://localhost:8081/data"
+    remote_base_path = "data"
+    result = webdav_client.filter_after_global_base_path(path, remote_base_path)
+    result2 = webdav_client.filter_after_global_base_path(path2, remote_base_path)
+    assert result == "subfolder1/text.txt"
+    # assert result2 ==
+
+
 def test_list_files(webdav_client) -> None:
-    url = os.path.join(webdav_client.hostname, "data")
+    remote_base_path = "data"
+    url = os.path.join(webdav_client.hostname, remote_base_path)
     url = url.replace(os.sep, "/")
-    files = webdav_client.list_files(url)
+    files = webdav_client.list_files(url, remote_base_path)
     assert "/data/Readme.md" in files
 
 
@@ -76,23 +87,16 @@ def test_list_folders(webdav_client):
     assert "subfolder3" in folders
 
 
-def test_filter_after_global_base_path(webdav_client) -> None:
-    path = "/data/subfolder1/text.txt"
-    remote_base_path = "data"
-    result = webdav_client.filter_after_global_base_path(path, remote_base_path)
-    assert result == "subfolder1/text.txt"
-
-
 def test_get_sub_path(webdav_client) -> None:
     full_path = "/data/subfolder1/text.txt"
-    initial_part = "/data"
+    initial_part = "data"
     result = webdav_client.get_sub_path(full_path, initial_part)
     assert result == "subfolder1/text.txt"
 
 
 def test_download_files(webdav_client) -> None:
 
-    global_remote_base_path = "http://localhost:8081"
+    global_remote_base_path = "data"
     remote_base_path = "data"
     local_base_path = "local_data"
 
@@ -108,10 +112,8 @@ def test_download_all_files_iterative(webdav_client):
     local_base_path = "local_data"
 
     webdav_client.download_all_files_iterative(remote_base_path, local_base_path)
-    assert os.path.exists(os.path.join(local_base_path, remote_base_path, "Readme.md"))
-    assert os.path.exists(
-        os.path.join(local_base_path, remote_base_path, "subfolder1/text.txt")
-    )
+    assert os.path.exists(os.path.join(local_base_path, "Readme.md"))
+    assert os.path.exists(os.path.join(local_base_path, "subfolder1/text.txt"))
 
 
 if __name__ == "__main__":
diff --git a/webdavclient/webdavclient.py b/webdavclient/webdavclient.py
index da2dbe3..2e078dd 100644
--- a/webdavclient/webdavclient.py
+++ b/webdavclient/webdavclient.py
@@ -57,7 +57,7 @@ def __init__(self, hostname: str, username: str, password: str) -> None:
         self.logger.addHandler(console_handler)
         self.logger.addHandler(file_handler)
 
-    def list_files(self, url: str) -> list[str]:
+    def list_files(self, url: str, remote_base_path: str) -> list[str]:
         """
         This method list all files from your WebDav host that stay under the
         url
@@ -224,6 +224,7 @@ def get_sub_path(self, full_path: str, initial_part: str) -> str:
             initial_part += "/"
 
         # Handle the edge case where the full path is exactly the initial part
+        initial_part = "/" + initial_part
         if full_path == initial_part.rstrip("/"):
             return ""
 
@@ -273,7 +274,7 @@ def download_files(
         url = os.path.join(self.hostname, remote_base_path)
         # as we communicate we do not want WINDWOS \ as os.sep!
         url = url.replace(os.sep, "/")
-        files_on_host = self.list_files(url)
+        files_on_host = self.list_files(url, global_remote_base_path)
 
         if len(files_on_host) == 0:
             self.logger.info("Found no files on remote_base_path: %s", remote_base_path)
@@ -292,9 +293,14 @@ def download_files(
             )
             remote_file_url = remote_file_url.replace(os.sep, "/")
             self.logger.info("The remote file url is: %s", remote_file_url)
-            sub_path = self.get_sub_path(remote_file_url, global_remote_base_path)
+            sub_path = self.get_sub_path(file_path, global_remote_base_path)
+
+            if sub_path == decoded_filename:
+                sub_path = ""
+
             self.logger.debug("The current sub path is: %s", sub_path)
             local_file_path = os.path.join(local_base_path, sub_path, decoded_filename)
+            local_file_path = local_file_path.replace(os.sep, "/")
             self.logger.debug(
                 "The current file that is stored has the full path: %s", local_file_path
             )
@@ -349,7 +355,7 @@ def download_all_files_iterative(
         # Initialize the stack with the root directory
         stack: list[str] = [remote_base_path]
 
-        # global_remote_base_path: str = remote_base_path
+        global_remote_base_path: str = remote_base_path
 
         while stack:
             current_remote_path: str = stack.pop()
@@ -357,7 +363,7 @@ def download_all_files_iterative(
 
             # Download files in the current directory
             self.download_files(
-                self.hostname,  # global_remote_base_path,
+                global_remote_base_path,
                 current_remote_path,
                 local_base_path,
             )