diff --git a/src/lib.rs b/src/lib.rs index e7b672b..c083cf6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -75,6 +75,14 @@ impl UrlPy { self.inner.scheme() } + #[getter] + fn host(&self) -> Option { + let host = self.inner.host()?; + Some(HostPy { + inner: host.to_owned(), + }) + } + #[getter] fn username(&self) -> &str { self.inner.username() @@ -122,10 +130,35 @@ impl UrlPy { } } +#[repr(transparent)] +#[pyclass(name = "Domain", module = "url", frozen)] +struct HostPy { + inner: url::Host, +} + +#[pymethods] +impl HostPy { + #[new] + fn new(value: String) -> Self { + Self { + inner: url::Host::Domain(value), + } + } + + fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyObject { + match op { + CompareOp::Eq => (self.inner == other.inner).into_py(py), + CompareOp::Ne => (self.inner != other.inner).into_py(py), + _ => py.NotImplemented(), + } + } +} + #[pymodule] #[pyo3(name = "url")] fn url_py(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; + m.add_class::()?; m.add("URLError", py.get_type::())?; m.add("EmptyHost", py.get_type::())?; diff --git a/tests/test_url.py b/tests/test_url.py index 591b0d6..45dfa27 100644 --- a/tests/test_url.py +++ b/tests/test_url.py @@ -4,7 +4,8 @@ import pytest -from url import URL, InvalidIPv6Address, RelativeURLWithoutBase, URLError +from url import URL +import url def test_https_url(): @@ -16,9 +17,7 @@ def test_https_url(): assert issue_url.password is None assert issue_url.host_str == "github.com" - # TODO: Decide what API makes sense here in Python -- - # get back an IP address and otherwise error, or some other API - # assert (issue_url.host() == Some(Host::Domain("github.com"))) + assert issue_url.host == url.Domain("github.com") assert issue_url.port is None assert issue_url.path == "/rust-lang/rust/issues" @@ -59,17 +58,17 @@ def test_slash(): def test_invalid_ipv6_address(): - with pytest.raises(InvalidIPv6Address): + with pytest.raises(url.InvalidIPv6Address): URL.parse("http://[:::1]") def test_invalid_relative_url_without_base(): - with pytest.raises(RelativeURLWithoutBase): + with pytest.raises(url.RelativeURLWithoutBase): URL.parse("../main.css") def test_invalid_junk(): - with pytest.raises(URLError): + with pytest.raises(url.URLError): URL.parse("https:/12949a;df;;@@@")