diff --git a/changelog.d/mastodon-websocket-protocol.fix b/changelog.d/mastodon-websocket-protocol.fix new file mode 100644 index 000000000..66dab16ed --- /dev/null +++ b/changelog.d/mastodon-websocket-protocol.fix @@ -0,0 +1 @@ +Echo Mastodon-style `Sec-WebSocket-Protocol` tokens in streaming WebSocket handshakes. diff --git a/lib/pleroma/web/endpoint.ex b/lib/pleroma/web/endpoint.ex index 36e78e1e2..07b33f866 100644 --- a/lib/pleroma/web/endpoint.ex +++ b/lib/pleroma/web/endpoint.ex @@ -9,8 +9,8 @@ defmodule Pleroma.Web.Endpoint do alias Pleroma.Config - socket("/api/v1/streaming", Pleroma.Web.MastodonAPI.WebsocketHandler, - longpoll: false, + plug(Pleroma.Web.MastodonAPI.WebsocketPlug, + path: "/api/v1/streaming", websocket: [ path: "/", compress: false, @@ -169,8 +169,7 @@ defmodule Pleroma.Web.Endpoint do else: "pleroma_key" extra = - Config.get([__MODULE__, :extra_cookie_attrs]) - |> Enum.join(";") + Enum.join(Config.get([__MODULE__, :extra_cookie_attrs]), ";") # The session will be stored in the cookie and signed, # this means its contents can be read but not tampered with. diff --git a/lib/pleroma/web/mastodon_api/websocket_handler.ex b/lib/pleroma/web/mastodon_api/websocket_handler.ex index 2b698bd5d..3dc862a5a 100644 --- a/lib/pleroma/web/mastodon_api/websocket_handler.ex +++ b/lib/pleroma/web/mastodon_api/websocket_handler.ex @@ -67,9 +67,10 @@ defmodule Pleroma.Web.MastodonAPI.WebsocketHandler do @impl Phoenix.Socket.Transport def handle_in({text, [opcode: :text]}, state) do - with {:ok, %{} = event} <- Jason.decode(text) do - handle_client_event(event, state) - else + case Jason.decode(text) do + {:ok, %{} = event} -> + handle_client_event(event, state) + _ -> Logger.error("#{__MODULE__} received non-JSON event: #{inspect(text)}") {:ok, state} @@ -85,11 +86,11 @@ defmodule Pleroma.Web.MastodonAPI.WebsocketHandler do def handle_info({:render_with_user, view, template, item, topic}, state) do user = %User{} = User.get_cached_by_ap_id(state.user.ap_id) - unless Streamer.filtered_by_user?(user, item) do + if Streamer.filtered_by_user?(user, item) do + {:ok, state} + else message = view.render(template, item, user, topic) {:push, {:text, message}, %{state | user: user}} - else - {:ok, state} end end @@ -253,7 +254,7 @@ defmodule Pleroma.Web.MastodonAPI.WebsocketHandler do defp find_sec_websocket_protocol(sec_headers) do Enum.find_value(sec_headers, fn - {"sec-websocket-protocol", token} -> token + {"sec-websocket-protocol", protocols} -> protocols |> Plug.Conn.Utils.list() |> List.first() _ -> nil end) end diff --git a/lib/pleroma/web/mastodon_api/websocket_plug.ex b/lib/pleroma/web/mastodon_api/websocket_plug.ex new file mode 100644 index 000000000..58ee913b4 --- /dev/null +++ b/lib/pleroma/web/mastodon_api/websocket_plug.ex @@ -0,0 +1,105 @@ +# Pleroma: A lightweight social networking server +# Copyright © 2017-2022 Pleroma Authors +# SPDX-License-Identifier: AGPL-3.0-only + +defmodule Pleroma.Web.MastodonAPI.WebsocketPlug do + @moduledoc """ + A Phoenix 1.8 compatible WebSocket transport for Mastodon streaming. + + It mirrors Phoenix.Transports.WebSocket, but echoes a successfully authenticated + Mastodon-style Sec-WebSocket-Protocol token so browser clients accept the handshake. + """ + + @behaviour Plug + + import Plug.Conn + + alias Phoenix.Socket.Transport + alias Pleroma.Web.Endpoint + alias Pleroma.Web.MastodonAPI.WebsocketHandler + + @connect_info_opts [:check_csrf] + + @impl Plug + def init(opts) do + path = String.split(Keyword.fetch!(opts, :path), "/", trim: true) + websocket = Keyword.fetch!(opts, :websocket) + config = Transport.load_config(websocket, Phoenix.Transports.WebSocket) + + {path, config} + end + + @impl Plug + def call(%{method: "GET", path_info: path} = conn, {path, opts}) do + conn + |> fetch_query_params() + |> Transport.code_reload(Endpoint, opts) + |> Transport.transport_log(opts[:transport_log]) + |> Transport.check_origin(WebsocketHandler, Endpoint, opts) + |> connect(opts) + end + + def call(%{path_info: path} = conn, {path, _opts}) do + conn + |> send_resp(400, "") + |> halt() + end + + def call(conn, _opts), do: conn + + defp connect(%{halted: true} = conn, _opts), do: conn + + defp connect(%{params: params} = conn, opts) do + keys = Keyword.get(opts, :connect_info, []) + + connect_info = + Transport.connect_info(conn, Endpoint, keys, Keyword.take(opts, @connect_info_opts)) + + config = %{ + endpoint: Endpoint, + transport: :websocket, + options: opts, + params: params, + connect_info: connect_info + } + + case WebsocketHandler.connect(config) do + {:ok, arg} -> + try do + conn + |> echo_sec_websocket_protocol() + |> WebSockAdapter.upgrade(WebsocketHandler, arg, opts) + |> halt() + rescue + e in WebSockAdapter.UpgradeError -> + conn + |> send_resp(400, e.message) + |> halt() + end + + :error -> + conn + |> send_resp(403, "") + |> halt() + + {:error, reason} -> + {m, f, args} = opts[:error_handler] + + halt(apply(m, f, [conn, reason | args])) + end + end + + defp echo_sec_websocket_protocol(conn) do + case get_req_header(conn, "sec-websocket-protocol") do + [protocols | _] -> + case Plug.Conn.Utils.list(protocols) do + [protocol | _] -> put_resp_header(conn, "sec-websocket-protocol", protocol) + nil -> conn + [] -> conn + end + + [] -> + conn + end + end +end diff --git a/test/pleroma/integration/mastodon_websocket_test.exs b/test/pleroma/integration/mastodon_websocket_test.exs index 078bb643c..de88e5002 100644 --- a/test/pleroma/integration/mastodon_websocket_test.exs +++ b/test/pleroma/integration/mastodon_websocket_test.exs @@ -11,6 +11,7 @@ defmodule Pleroma.Integration.MastodonWebsocketTest do alias Pleroma.Integration.WebsocketClient alias Pleroma.Web.CommonAPI + alias Pleroma.Web.MastodonAPI.StatusView alias Pleroma.Web.OAuth @moduletag needs_streamer: true, capture_log: true @@ -31,6 +32,48 @@ defmodule Pleroma.Integration.MastodonWebsocketTest do WebsocketClient.start_link(self(), path, headers) end + defp raw_websocket_handshake(qs, headers) do + uri = URI.parse(@path <> qs) + port = uri.port || 80 + path = uri.path <> if(uri.query, do: "?" <> uri.query, else: "") + + default_headers = [ + {"host", "#{uri.host}:#{port}"}, + {"upgrade", "websocket"}, + {"connection", "Upgrade"}, + {"sec-websocket-key", Base.encode64(:crypto.strong_rand_bytes(16))}, + {"sec-websocket-version", "13"} + ] + + request = [ + "GET #{path} HTTP/1.1\r\n", + Enum.map(default_headers ++ headers, fn {name, value} -> "#{name}: #{value}\r\n" end), + "\r\n" + ] + + with {:ok, socket} <- + :gen_tcp.connect(String.to_charlist(uri.host), port, [:binary, active: false], 1_000), + :ok <- :gen_tcp.send(socket, request), + {:ok, response} <- :gen_tcp.recv(socket, 0, 1_000) do + :gen_tcp.close(socket) + {:ok, parse_http_response(response)} + end + end + + defp parse_http_response(response) do + [headers | _] = String.split(response, "\r\n\r\n", parts: 2) + [status_line | header_lines] = String.split(headers, "\r\n") + [_, status | _] = String.split(status_line, " ") + + headers = + Enum.map(header_lines, fn line -> + [name, value] = String.split(line, ":", parts: 2) + {String.downcase(name), String.trim(value)} + end) + + %{status: String.to_integer(status), headers: headers} + end + defp decode_json(json) do with {:ok, %{"event" => event, "payload" => payload_text}} <- Jason.decode(json), {:ok, payload} <- Jason.decode(payload_text) do @@ -85,9 +128,7 @@ defmodule Pleroma.Integration.MastodonWebsocketTest do assert json["payload"] assert {:ok, json} = Jason.decode(json["payload"]) - view_json = - Pleroma.Web.MastodonAPI.StatusView.render("show.json", activity: activity, for: nil) - |> atom_key_to_string() + view_json = atom_key_to_string(StatusView.render("show.json", activity: activity, for: nil)) assert json == view_json end @@ -114,10 +155,7 @@ defmodule Pleroma.Integration.MastodonWebsocketTest do assert json["payload"] assert {:ok, json} = Jason.decode(json["payload"]) - view_json = - Pleroma.Web.MastodonAPI.StatusView.render("show.json", activity: activity, for: nil) - |> Jason.encode!() - |> Jason.decode!() + view_json = atom_key_to_string(StatusView.render("show.json", activity: activity, for: nil)) assert json == view_json end @@ -279,6 +317,34 @@ defmodule Pleroma.Integration.MastodonWebsocketTest do end) end + test "echoes the Sec-WebSocket-Protocol token in the handshake", %{token: token} do + assert {:ok, %{status: 101, headers: headers}} = + raw_websocket_handshake("?stream=user", [ + {"sec-websocket-protocol", token.token} + ]) + + assert {"sec-websocket-protocol", token.token} in headers + end + + test "echoes the selected Sec-WebSocket-Protocol token", %{token: token} do + assert {:ok, %{status: 101, headers: headers}} = + raw_websocket_handshake("?stream=user", [ + {"sec-websocket-protocol", "#{token.token}, phoenix"} + ]) + + assert {"sec-websocket-protocol", token.token} in headers + end + + test "does not echo an invalid Sec-WebSocket-Protocol token", %{token: token} do + assert {:ok, %{status: 401, headers: headers}} = + raw_websocket_handshake("?stream=user", [ + {"sec-websocket-protocol", "invalid"} + ]) + + refute {"sec-websocket-protocol", token.token} in headers + refute List.keymember?(headers, "sec-websocket-protocol", 0) + end + test "prefers sec-websocket-protocol token over query access_token", %{ token: token, user: user @@ -450,12 +516,12 @@ defmodule Pleroma.Integration.MastodonWebsocketTest do assert {:ok, json} = Jason.decode(json["payload"]) view_json = - Pleroma.Web.MastodonAPI.StatusView.render("show.json", - activity: activity, - for: reading_user + atom_key_to_string( + StatusView.render("show.json", + activity: activity, + for: reading_user + ) ) - |> Jason.encode!() - |> Jason.decode!() assert json == view_json end @@ -478,12 +544,12 @@ defmodule Pleroma.Integration.MastodonWebsocketTest do activity = Pleroma.Activity.normalize(activity) view_json = - Pleroma.Web.MastodonAPI.StatusView.render("show.json", - activity: activity, - for: reading_user + atom_key_to_string( + StatusView.render("show.json", + activity: activity, + for: reading_user + ) ) - |> Jason.encode!() - |> Jason.decode!() assert {:ok, %{"event" => "status.update", "payload" => ^view_json}} = decode_json(raw_json) end