Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 68 additions & 8 deletions lib/ch/connection.ex
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ defmodule Ch.Connection do

@user_agent "ch/" <> Mix.Project.config()[:version]

@server_display_name_key :server_display_name

@typep conn :: HTTP.t()

@impl true
Expand All @@ -25,6 +27,12 @@ defmodule Ch.Connection do
|> maybe_put_private(:username, opts[:username])
|> maybe_put_private(:password, opts[:password])
|> maybe_put_private(:settings, opts[:settings])
|> HTTP.put_private(:reconnect_opts, %{
scheme: scheme,
address: address,
port: port,
mint_opts: mint_opts
})

handshake = Query.build("select 1, version()")
params = DBConnection.Query.encode(handshake, _params = [], _opts = [])
Expand Down Expand Up @@ -364,19 +372,23 @@ defmodule Ch.Connection do
| {:error, Error.t(), conn}
| {:disconnect, Mint.Types.error(), conn}
defp request(conn, method, path, headers, body, opts) do
with {:ok, conn, _ref} <- send_request(conn, method, path, headers, body) do
receive_full_response(conn, timeout(conn, opts))
end
with_retry_if_stale_connection(conn, fn conn ->
with {:ok, conn, _ref} <- send_request(conn, method, path, headers, body) do
receive_full_response(conn, timeout(conn, opts))
end
end)
end

@spec request_chunked(conn, binary, binary, Mint.Types.headers(), Enumerable.t(), Keyword.t()) ::
{:ok, conn, [response]}
| {:error, Error.t(), conn}
| {:disconnect, Mint.Types.error(), conn}
def request_chunked(conn, method, path, headers, stream, opts) do
with {:ok, conn, ref} <- send_request(conn, method, path, headers, :stream),
{:ok, conn} <- stream_body(conn, ref, stream),
do: receive_full_response(conn, timeout(conn, opts))
with_retry_if_stale_connection(conn, fn conn ->
with {:ok, conn, ref} <- send_request(conn, method, path, headers, :stream),
{:ok, conn} <- stream_body(conn, ref, stream),
do: receive_full_response(conn, timeout(conn, opts))
end)
end

@spec stream_body(conn, Mint.Types.request_ref(), Enumerable.t()) ::
Expand Down Expand Up @@ -405,6 +417,56 @@ defmodule Ch.Connection do
end
end

defp with_retry_if_stale_connection(conn, fun) do
case fun.(conn) do
{:disconnect, reason, conn} ->
if reconnectable_error?(reason) do
case reconnect(conn) do
{:ok, new_conn} ->
fun.(new_conn)

{:error, reason} ->
{:disconnect, reason, conn}
end
else
{:disconnect, reason, conn}
end

other ->
other
end
end

defp reconnectable_error?(%Mint.TransportError{reason: :closed}), do: true
defp reconnectable_error?(%Mint.TransportError{reason: :econnreset}), do: true
defp reconnectable_error?(_), do: false

@spec reconnect(conn) :: {:ok, conn} | {:error, Mint.Types.error()}
defp reconnect(conn) do
%{scheme: scheme, address: address, port: port, mint_opts: mint_opts} =
HTTP.get_private(conn, :reconnect_opts)

{:ok, _closed_conn} = HTTP.close(conn)

case HTTP.connect(scheme, address, port, mint_opts) do
{:ok, new_conn} ->
new_conn =
new_conn
|> HTTP.put_private(:timeout, HTTP.get_private(conn, :timeout))
|> maybe_put_private(:database, HTTP.get_private(conn, :database))
|> maybe_put_private(:username, HTTP.get_private(conn, :username))
|> maybe_put_private(:password, HTTP.get_private(conn, :password))
|> maybe_put_private(:settings, HTTP.get_private(conn, :settings))
|> HTTP.put_private(:reconnect_opts, HTTP.get_private(conn, :reconnect_opts))
|> maybe_put_private(@server_display_name_key, HTTP.get_private(conn, @server_display_name_key))

{:ok, new_conn}

{:error, _reason} = error ->
error
end
end

@spec receive_full_response(conn, timeout) ::
{:ok, conn, [response]}
| {:error, Error.t(), conn}
Expand Down Expand Up @@ -499,8 +561,6 @@ defmodule Ch.Connection do
"/?" <> URI.encode_query(settings ++ query_params)
end

@server_display_name_key :server_display_name

@spec ensure_same_server(conn, Mint.Types.headers()) :: conn
defp ensure_same_server(conn, headers) do
expected_name = HTTP.get_private(conn, @server_display_name_key)
Expand Down
Loading