Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing custom dns_resolver to client.connect #120

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 233 additions & 9 deletions http/client.lua
Original file line number Diff line number Diff line change
@@ -1,15 +1,37 @@
local monotime = require "cqueues".monotime
local ca = require "cqueues.auxlib"
local ce = require "cqueues.errno"
local cs = require "cqueues.socket"
local cqueues_dns = require "cqueues.dns"
local cqueues_dns_record = require "cqueues.dns.record"
local http_tls = require "http.tls"
local http_util = require "http.util"
local connection_common = require "http.connection_common"
local onerror = connection_common.onerror
local new_h1_connection = require "http.h1_connection".new
local new_h2_connection = require "http.h2_connection".new
local lpeg = require "lpeg"
local IPv4_patts = require "lpeg_patterns.IPv4"
local IPv6_patts = require "lpeg_patterns.IPv6"
local openssl_ssl = require "openssl.ssl"
local openssl_ctx = require "openssl.ssl.context"
local openssl_verify_param = require "openssl.x509.verify_param"

local AF_UNSPEC = cs.AF_UNSPEC
local AF_UNIX = cs.AF_UNIX
local AF_INET = cs.AF_INET
local AF_INET6 = cs.AF_INET6

local DNS_SECTION_ANSWER = cqueues_dns_record.ANSWER
local DNS_CLASS_IN = cqueues_dns_record.IN
local DNS_TYPE_A = cqueues_dns_record.A
local DNS_TYPE_AAAA = cqueues_dns_record.AAAA
local DNS_TYPE_CNAME = cqueues_dns_record.CNAME

local EOF = lpeg.P(-1)
local IPv4address = IPv4_patts.IPv4address * EOF
local IPv6addrz = IPv6_patts.IPv6addrz * EOF

-- Create a shared 'default' TLS context
local default_ctx = http_tls.new_client_context()

Expand Down Expand Up @@ -80,7 +102,177 @@ local function negotiate(s, options, timeout)
end
end

-- `type` parameter is what sort of records you want to find could be "A" or
-- "AAAA" or `nil` if you want to filter yourself e.g. to implement
-- https://www.ietf.org/archive/id/draft-vavrusa-dnsop-aaaa-for-free-00.txt
local function each_matching_record(pkt, name, type)
-- First need to do CNAME chasing
local params = {
section = DNS_SECTION_ANSWER;
class = DNS_CLASS_IN;
type = DNS_TYPE_CNAME;
name = name .. ".";
}
for _=1, 8 do -- avoid cname loops
-- Ignores any CNAME record past the first (which should never occur anyway)
local func, state, first = pkt:grep(params)
local record = func(state, first)
if record == nil then
-- Not found
break
end
params.name = record:host()
end
params.type = type
return pkt:grep(params)
end

local function dns_lookup(records, dns_resolver, host, port, query_type, filter_type, timeout)
local packet = dns_resolver:query(host, query_type, nil, timeout)
if not packet then
return
end
for rec in each_matching_record(packet, host, filter_type) do
local t = rec:type()
if t == DNS_TYPE_AAAA then
records:add_v6(rec:addr(), port)
elseif t == DNS_TYPE_A then
records:add_v4(rec:addr(), port)
end
end
end

local records_methods = {}
local records_mt = {
__name = "http.client.records";
__index = records_methods;
}

local function new_records()
return setmetatable({
n = 0;
nil -- preallocate space for one
}, records_mt)
end

function records_mt:__len()
return self.n
end

local record_ipv4_methods = {
family = AF_INET;
}
local record_ipv4_mt = {
__name = "http.client.record.ipv4";
__index = record_ipv4_methods;
}
function records_methods:add_v4(addr, port)
local n = self.n + 1
self[n] = setmetatable({ addr = addr, port = port }, record_ipv4_mt)
self.n = n
end

local record_ipv6_methods = {
family = AF_INET6;
}
local record_ipv6_mt = {
__name = "http.client.record.ipv6";
__index = record_ipv6_methods;
}
function records_methods:add_v6(addr, port)
if type(addr) == "string" then
-- Normalise
addr = assert(IPv6addrz:match(addr))
elseif getmetatable(addr) ~= IPv6_patts.IPv6_mt then
error("invalid argument")
end
addr = tostring(addr)
local n = self.n + 1
self[n] = setmetatable({ addr = addr, port = port }, record_ipv6_mt)
self.n = n
end

local record_unix_methods = {
family = AF_UNIX;
}
local record_unix_mt = {
__name = "http.client.record.unix";
__index = record_unix_methods;
}
function records_methods:add_unix(path)
local n = self.n + 1
self[n] = setmetatable({ path = path }, record_unix_mt)
self.n = n
end

function records_methods:remove_family(family)
if family == nil then
family = AF_UNSPEC
end

for i=self.n, 1, -1 do
if self[i].family == family then
table.remove(self, i)
self.n = self.n - 1
end
end
end

local function lookup_records(options, timeout)
local family = options.family
if family == nil then
family = AF_UNSPEC
end

local records = new_records()

local path = options.path
if path then
if family ~= AF_UNSPEC and family ~= AF_UNIX then
error("cannot use .path with non-unix address family")
end
records:add_unix(path)
return records
end

local host = options.host
local port = options.port

local ipv4 = IPv4address:match(host)
if ipv4 then
if family == AF_UNSPEC or family == AF_INET then
records:add_v4(host, port)
end
return records
end

local ipv6 = IPv6addrz:match(host)
if ipv6 then
if family == AF_UNSPEC or family == AF_INET6 then
records:add_v6(ipv6, port)
end
return records
end

local dns_resolver = options.dns_resolver or cqueues_dns.getpool()
if family == AF_UNSPEC then
local deadline = timeout and monotime()+timeout
dns_lookup(records, dns_resolver, host, port, DNS_TYPE_AAAA, nil, timeout)
dns_lookup(records, dns_resolver, host, port, DNS_TYPE_A, nil, deadline and deadline-monotime())
elseif family == AF_INET then
dns_lookup(records, dns_resolver, host, port, DNS_TYPE_A, DNS_TYPE_A, timeout)
elseif family == AF_INET6 then
dns_lookup(records, dns_resolver, host, port, DNS_TYPE_AAAA, DNS_TYPE_AAAA, timeout)
end

return records
end

local function connect(options, timeout)
local deadline = timeout and monotime()+timeout

local records = lookup_records(options, timeout)

local bind = options.bind
if bind ~= nil then
assert(type(bind) == "string")
Expand All @@ -99,20 +291,52 @@ local function connect(options, timeout)
port = bind_port;
}
end
local s, err, errno = ca.fileresult(cs.connect {
family = options.family;
host = options.host;
port = options.port;
path = options.path;

local connect_params = {
family = nil;
host = nil;
port = nil;
path = nil;
bind = bind;
sendname = false;
v6only = options.v6only;
nodelay = true;
})
if s == nil then
return nil, err, errno
}

local lasterr, lasterrno = "The name does not resolve for the supplied parameters"
local i = 1
while i <= records.n do
local rec = records[i]
connect_params.family = rec.family;
connect_params.host = rec.addr;
connect_params.port = rec.port;
connect_params.path = rec.path;
local s
s, lasterr, lasterrno = ca.fileresult(cs.connect(connect_params))
if s then
local c
c, lasterr, lasterrno = negotiate(s, options, deadline and deadline-monotime())
if c then
-- Force TCP connect to occur
local ok
ok, lasterr, lasterrno = c:connect(deadline and deadline-monotime())
if ok then
return c
end
c:close()
else
s:close()
end
end
if lasterrno == ce.EAFNOSUPPORT then
-- If an address family is not supported then entirely remove that
-- family from candidate records
records:remove_family(connect_params.family)
else
i = i + 1
end
end
return negotiate(s, options, timeout)
return nil, lasterr, lasterrno
end

return {
Expand Down
79 changes: 70 additions & 9 deletions spec/client_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,95 @@ describe("http.client module", function()
local http_h1_connection = require "http.h1_connection"
local http_h2_connection = require "http.h2_connection"
local http_headers = require "http.headers"
local http_server = require "http.server"
local http_tls = require "http.tls"
local cqueues = require "cqueues"
local ca = require "cqueues.auxlib"
local cs = require "cqueues.socket"
local cdh = require "cqueues.dns.hosts"
local cdr = require "cqueues.dns.resolver"
local cdrs = require "cqueues.dns.resolvers"
local openssl_pkey = require "openssl.pkey"
local openssl_ctx = require "openssl.ssl.context"
local openssl_x509 = require "openssl.x509"
it("throws error on invalid family+path combination", function()
assert.has.errors(function()
client.connect{family = cs.AF_INET, path = "/somepath"}
end)
end)
it("invalid network parameters return nil, err, errno", function()
-- Invalid network parameters will return nil, err, errno
local ok, err, errno = client.connect{host="127.0.0.1", port="invalid"}
assert.same(nil, ok)
assert.same("string", type(err))
assert.same("number", type(errno))
end)
local function send_request(conn)
local stream = conn:new_stream()
local req_headers = http_headers.new()
req_headers:append(":authority", "myauthority")
req_headers:append(":method", "GET")
req_headers:append(":path", "/")
req_headers:append(":scheme", conn:checktls() and "https" or "http")
assert(stream:write_headers(req_headers, true))
local res_headers = assert(stream:get_headers())
assert.same("200", res_headers:get(":status"))
end
local function test(client_cb)
local cq = cqueues.new()
local s = assert(http_server.listen {
host = "localhost";
port = 0;
onstream = function(s, stream)
assert(stream:get_headers())
local resp_headers = http_headers.new()
resp_headers:append(":status", "200")
assert(stream:write_headers(resp_headers, false))
assert(stream:write_chunk("hello world", true))
stream:shutdown()
stream.connection:shutdown()
s:close()
end;
})
assert(s:listen())
local family, host, port = s:localname()
cq:wrap(function()
assert_loop(s)
end)
cq:wrap(client_cb, family, host, port)
assert_loop(cq, TEST_TIMEOUT)
assert.truthy(cq:empty())
end
it("works with a cqueues.dns.resolver object", function()
test(function(family, ip, port)
local hosts = cdh.new()
hosts:insert(ip, "example.com")
send_request(assert(client.connect {
dns_resolver = cdr.new(nil, hosts);
family = family;
host = "example.com";
port = port;
}))
end)
end)
it("works with a cqueues.dns.resolvers object", function()
test(function(family, ip, port)
local hosts = cdh.new()
hosts:insert(ip, "example.com")
send_request(assert(client.connect {
dns_resolver = cdrs.new(nil, hosts);
family = family;
host = "example.com";
port = port;
}))
end)
end)
local function test_pair(client_options, server_func)
local s, c = ca.assert(cs.pair())
local cq = cqueues.new();
cq:wrap(function()
local conn = assert(client.negotiate(c, client_options))
local stream = conn:new_stream()
local req_headers = http_headers.new()
req_headers:append(":authority", "myauthority")
req_headers:append(":method", "GET")
req_headers:append(":path", "/")
req_headers:append(":scheme", client_options.tls and "https" or "http")
assert(stream:write_headers(req_headers, true))
local res_headers = assert(stream:get_headers())
assert.same("200", res_headers:get(":status"))
send_request(conn)
end)
cq:wrap(function()
s = server_func(s)
Expand Down