Fix WebSocket protocol token handshakes
This commit is contained in:
parent
093b156c65
commit
9dd02ecd50
5 changed files with 200 additions and 28 deletions
1
changelog.d/mastodon-websocket-protocol.fix
Normal file
1
changelog.d/mastodon-websocket-protocol.fix
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
Echo Mastodon-style `Sec-WebSocket-Protocol` tokens in streaming WebSocket handshakes.
|
||||||
|
|
@ -9,8 +9,8 @@ defmodule Pleroma.Web.Endpoint do
|
||||||
|
|
||||||
alias Pleroma.Config
|
alias Pleroma.Config
|
||||||
|
|
||||||
socket("/api/v1/streaming", Pleroma.Web.MastodonAPI.WebsocketHandler,
|
plug(Pleroma.Web.MastodonAPI.WebsocketPlug,
|
||||||
longpoll: false,
|
path: "/api/v1/streaming",
|
||||||
websocket: [
|
websocket: [
|
||||||
path: "/",
|
path: "/",
|
||||||
compress: false,
|
compress: false,
|
||||||
|
|
@ -169,8 +169,7 @@ defmodule Pleroma.Web.Endpoint do
|
||||||
else: "pleroma_key"
|
else: "pleroma_key"
|
||||||
|
|
||||||
extra =
|
extra =
|
||||||
Config.get([__MODULE__, :extra_cookie_attrs])
|
Enum.join(Config.get([__MODULE__, :extra_cookie_attrs]), ";")
|
||||||
|> Enum.join(";")
|
|
||||||
|
|
||||||
# The session will be stored in the cookie and signed,
|
# The session will be stored in the cookie and signed,
|
||||||
# this means its contents can be read but not tampered with.
|
# this means its contents can be read but not tampered with.
|
||||||
|
|
|
||||||
|
|
@ -67,9 +67,10 @@ defmodule Pleroma.Web.MastodonAPI.WebsocketHandler do
|
||||||
|
|
||||||
@impl Phoenix.Socket.Transport
|
@impl Phoenix.Socket.Transport
|
||||||
def handle_in({text, [opcode: :text]}, state) do
|
def handle_in({text, [opcode: :text]}, state) do
|
||||||
with {:ok, %{} = event} <- Jason.decode(text) do
|
case Jason.decode(text) do
|
||||||
handle_client_event(event, state)
|
{:ok, %{} = event} ->
|
||||||
else
|
handle_client_event(event, state)
|
||||||
|
|
||||||
_ ->
|
_ ->
|
||||||
Logger.error("#{__MODULE__} received non-JSON event: #{inspect(text)}")
|
Logger.error("#{__MODULE__} received non-JSON event: #{inspect(text)}")
|
||||||
{:ok, state}
|
{:ok, state}
|
||||||
|
|
@ -85,11 +86,11 @@ defmodule Pleroma.Web.MastodonAPI.WebsocketHandler do
|
||||||
def handle_info({:render_with_user, view, template, item, topic}, state) do
|
def handle_info({:render_with_user, view, template, item, topic}, state) do
|
||||||
user = %User{} = User.get_cached_by_ap_id(state.user.ap_id)
|
user = %User{} = User.get_cached_by_ap_id(state.user.ap_id)
|
||||||
|
|
||||||
unless Streamer.filtered_by_user?(user, item) do
|
if Streamer.filtered_by_user?(user, item) do
|
||||||
|
{:ok, state}
|
||||||
|
else
|
||||||
message = view.render(template, item, user, topic)
|
message = view.render(template, item, user, topic)
|
||||||
{:push, {:text, message}, %{state | user: user}}
|
{:push, {:text, message}, %{state | user: user}}
|
||||||
else
|
|
||||||
{:ok, state}
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
@ -253,7 +254,7 @@ defmodule Pleroma.Web.MastodonAPI.WebsocketHandler do
|
||||||
|
|
||||||
defp find_sec_websocket_protocol(sec_headers) do
|
defp find_sec_websocket_protocol(sec_headers) do
|
||||||
Enum.find_value(sec_headers, fn
|
Enum.find_value(sec_headers, fn
|
||||||
{"sec-websocket-protocol", token} -> token
|
{"sec-websocket-protocol", protocols} -> protocols |> Plug.Conn.Utils.list() |> List.first()
|
||||||
_ -> nil
|
_ -> nil
|
||||||
end)
|
end)
|
||||||
end
|
end
|
||||||
|
|
|
||||||
105
lib/pleroma/web/mastodon_api/websocket_plug.ex
Normal file
105
lib/pleroma/web/mastodon_api/websocket_plug.ex
Normal file
|
|
@ -0,0 +1,105 @@
|
||||||
|
# Pleroma: A lightweight social networking server
|
||||||
|
# Copyright © 2017-2022 Pleroma Authors <https://pleroma.social/>
|
||||||
|
# SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
|
||||||
|
defmodule Pleroma.Web.MastodonAPI.WebsocketPlug do
|
||||||
|
@moduledoc """
|
||||||
|
A Phoenix 1.8 compatible WebSocket transport for Mastodon streaming.
|
||||||
|
|
||||||
|
It mirrors Phoenix.Transports.WebSocket, but echoes a successfully authenticated
|
||||||
|
Mastodon-style Sec-WebSocket-Protocol token so browser clients accept the handshake.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@behaviour Plug
|
||||||
|
|
||||||
|
import Plug.Conn
|
||||||
|
|
||||||
|
alias Phoenix.Socket.Transport
|
||||||
|
alias Pleroma.Web.Endpoint
|
||||||
|
alias Pleroma.Web.MastodonAPI.WebsocketHandler
|
||||||
|
|
||||||
|
@connect_info_opts [:check_csrf]
|
||||||
|
|
||||||
|
@impl Plug
|
||||||
|
def init(opts) do
|
||||||
|
path = String.split(Keyword.fetch!(opts, :path), "/", trim: true)
|
||||||
|
websocket = Keyword.fetch!(opts, :websocket)
|
||||||
|
config = Transport.load_config(websocket, Phoenix.Transports.WebSocket)
|
||||||
|
|
||||||
|
{path, config}
|
||||||
|
end
|
||||||
|
|
||||||
|
@impl Plug
|
||||||
|
def call(%{method: "GET", path_info: path} = conn, {path, opts}) do
|
||||||
|
conn
|
||||||
|
|> fetch_query_params()
|
||||||
|
|> Transport.code_reload(Endpoint, opts)
|
||||||
|
|> Transport.transport_log(opts[:transport_log])
|
||||||
|
|> Transport.check_origin(WebsocketHandler, Endpoint, opts)
|
||||||
|
|> connect(opts)
|
||||||
|
end
|
||||||
|
|
||||||
|
def call(%{path_info: path} = conn, {path, _opts}) do
|
||||||
|
conn
|
||||||
|
|> send_resp(400, "")
|
||||||
|
|> halt()
|
||||||
|
end
|
||||||
|
|
||||||
|
def call(conn, _opts), do: conn
|
||||||
|
|
||||||
|
defp connect(%{halted: true} = conn, _opts), do: conn
|
||||||
|
|
||||||
|
defp connect(%{params: params} = conn, opts) do
|
||||||
|
keys = Keyword.get(opts, :connect_info, [])
|
||||||
|
|
||||||
|
connect_info =
|
||||||
|
Transport.connect_info(conn, Endpoint, keys, Keyword.take(opts, @connect_info_opts))
|
||||||
|
|
||||||
|
config = %{
|
||||||
|
endpoint: Endpoint,
|
||||||
|
transport: :websocket,
|
||||||
|
options: opts,
|
||||||
|
params: params,
|
||||||
|
connect_info: connect_info
|
||||||
|
}
|
||||||
|
|
||||||
|
case WebsocketHandler.connect(config) do
|
||||||
|
{:ok, arg} ->
|
||||||
|
try do
|
||||||
|
conn
|
||||||
|
|> echo_sec_websocket_protocol()
|
||||||
|
|> WebSockAdapter.upgrade(WebsocketHandler, arg, opts)
|
||||||
|
|> halt()
|
||||||
|
rescue
|
||||||
|
e in WebSockAdapter.UpgradeError ->
|
||||||
|
conn
|
||||||
|
|> send_resp(400, e.message)
|
||||||
|
|> halt()
|
||||||
|
end
|
||||||
|
|
||||||
|
:error ->
|
||||||
|
conn
|
||||||
|
|> send_resp(403, "")
|
||||||
|
|> halt()
|
||||||
|
|
||||||
|
{:error, reason} ->
|
||||||
|
{m, f, args} = opts[:error_handler]
|
||||||
|
|
||||||
|
halt(apply(m, f, [conn, reason | args]))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
defp echo_sec_websocket_protocol(conn) do
|
||||||
|
case get_req_header(conn, "sec-websocket-protocol") do
|
||||||
|
[protocols | _] ->
|
||||||
|
case Plug.Conn.Utils.list(protocols) do
|
||||||
|
[protocol | _] -> put_resp_header(conn, "sec-websocket-protocol", protocol)
|
||||||
|
nil -> conn
|
||||||
|
[] -> conn
|
||||||
|
end
|
||||||
|
|
||||||
|
[] ->
|
||||||
|
conn
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
@ -11,6 +11,7 @@ defmodule Pleroma.Integration.MastodonWebsocketTest do
|
||||||
|
|
||||||
alias Pleroma.Integration.WebsocketClient
|
alias Pleroma.Integration.WebsocketClient
|
||||||
alias Pleroma.Web.CommonAPI
|
alias Pleroma.Web.CommonAPI
|
||||||
|
alias Pleroma.Web.MastodonAPI.StatusView
|
||||||
alias Pleroma.Web.OAuth
|
alias Pleroma.Web.OAuth
|
||||||
|
|
||||||
@moduletag needs_streamer: true, capture_log: true
|
@moduletag needs_streamer: true, capture_log: true
|
||||||
|
|
@ -31,6 +32,48 @@ defmodule Pleroma.Integration.MastodonWebsocketTest do
|
||||||
WebsocketClient.start_link(self(), path, headers)
|
WebsocketClient.start_link(self(), path, headers)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
defp raw_websocket_handshake(qs, headers) do
|
||||||
|
uri = URI.parse(@path <> qs)
|
||||||
|
port = uri.port || 80
|
||||||
|
path = uri.path <> if(uri.query, do: "?" <> uri.query, else: "")
|
||||||
|
|
||||||
|
default_headers = [
|
||||||
|
{"host", "#{uri.host}:#{port}"},
|
||||||
|
{"upgrade", "websocket"},
|
||||||
|
{"connection", "Upgrade"},
|
||||||
|
{"sec-websocket-key", Base.encode64(:crypto.strong_rand_bytes(16))},
|
||||||
|
{"sec-websocket-version", "13"}
|
||||||
|
]
|
||||||
|
|
||||||
|
request = [
|
||||||
|
"GET #{path} HTTP/1.1\r\n",
|
||||||
|
Enum.map(default_headers ++ headers, fn {name, value} -> "#{name}: #{value}\r\n" end),
|
||||||
|
"\r\n"
|
||||||
|
]
|
||||||
|
|
||||||
|
with {:ok, socket} <-
|
||||||
|
:gen_tcp.connect(String.to_charlist(uri.host), port, [:binary, active: false], 1_000),
|
||||||
|
:ok <- :gen_tcp.send(socket, request),
|
||||||
|
{:ok, response} <- :gen_tcp.recv(socket, 0, 1_000) do
|
||||||
|
:gen_tcp.close(socket)
|
||||||
|
{:ok, parse_http_response(response)}
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
defp parse_http_response(response) do
|
||||||
|
[headers | _] = String.split(response, "\r\n\r\n", parts: 2)
|
||||||
|
[status_line | header_lines] = String.split(headers, "\r\n")
|
||||||
|
[_, status | _] = String.split(status_line, " ")
|
||||||
|
|
||||||
|
headers =
|
||||||
|
Enum.map(header_lines, fn line ->
|
||||||
|
[name, value] = String.split(line, ":", parts: 2)
|
||||||
|
{String.downcase(name), String.trim(value)}
|
||||||
|
end)
|
||||||
|
|
||||||
|
%{status: String.to_integer(status), headers: headers}
|
||||||
|
end
|
||||||
|
|
||||||
defp decode_json(json) do
|
defp decode_json(json) do
|
||||||
with {:ok, %{"event" => event, "payload" => payload_text}} <- Jason.decode(json),
|
with {:ok, %{"event" => event, "payload" => payload_text}} <- Jason.decode(json),
|
||||||
{:ok, payload} <- Jason.decode(payload_text) do
|
{:ok, payload} <- Jason.decode(payload_text) do
|
||||||
|
|
@ -85,9 +128,7 @@ defmodule Pleroma.Integration.MastodonWebsocketTest do
|
||||||
assert json["payload"]
|
assert json["payload"]
|
||||||
assert {:ok, json} = Jason.decode(json["payload"])
|
assert {:ok, json} = Jason.decode(json["payload"])
|
||||||
|
|
||||||
view_json =
|
view_json = atom_key_to_string(StatusView.render("show.json", activity: activity, for: nil))
|
||||||
Pleroma.Web.MastodonAPI.StatusView.render("show.json", activity: activity, for: nil)
|
|
||||||
|> atom_key_to_string()
|
|
||||||
|
|
||||||
assert json == view_json
|
assert json == view_json
|
||||||
end
|
end
|
||||||
|
|
@ -114,10 +155,7 @@ defmodule Pleroma.Integration.MastodonWebsocketTest do
|
||||||
assert json["payload"]
|
assert json["payload"]
|
||||||
assert {:ok, json} = Jason.decode(json["payload"])
|
assert {:ok, json} = Jason.decode(json["payload"])
|
||||||
|
|
||||||
view_json =
|
view_json = atom_key_to_string(StatusView.render("show.json", activity: activity, for: nil))
|
||||||
Pleroma.Web.MastodonAPI.StatusView.render("show.json", activity: activity, for: nil)
|
|
||||||
|> Jason.encode!()
|
|
||||||
|> Jason.decode!()
|
|
||||||
|
|
||||||
assert json == view_json
|
assert json == view_json
|
||||||
end
|
end
|
||||||
|
|
@ -279,6 +317,34 @@ defmodule Pleroma.Integration.MastodonWebsocketTest do
|
||||||
end)
|
end)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
test "echoes the Sec-WebSocket-Protocol token in the handshake", %{token: token} do
|
||||||
|
assert {:ok, %{status: 101, headers: headers}} =
|
||||||
|
raw_websocket_handshake("?stream=user", [
|
||||||
|
{"sec-websocket-protocol", token.token}
|
||||||
|
])
|
||||||
|
|
||||||
|
assert {"sec-websocket-protocol", token.token} in headers
|
||||||
|
end
|
||||||
|
|
||||||
|
test "echoes the selected Sec-WebSocket-Protocol token", %{token: token} do
|
||||||
|
assert {:ok, %{status: 101, headers: headers}} =
|
||||||
|
raw_websocket_handshake("?stream=user", [
|
||||||
|
{"sec-websocket-protocol", "#{token.token}, phoenix"}
|
||||||
|
])
|
||||||
|
|
||||||
|
assert {"sec-websocket-protocol", token.token} in headers
|
||||||
|
end
|
||||||
|
|
||||||
|
test "does not echo an invalid Sec-WebSocket-Protocol token", %{token: token} do
|
||||||
|
assert {:ok, %{status: 401, headers: headers}} =
|
||||||
|
raw_websocket_handshake("?stream=user", [
|
||||||
|
{"sec-websocket-protocol", "invalid"}
|
||||||
|
])
|
||||||
|
|
||||||
|
refute {"sec-websocket-protocol", token.token} in headers
|
||||||
|
refute List.keymember?(headers, "sec-websocket-protocol", 0)
|
||||||
|
end
|
||||||
|
|
||||||
test "prefers sec-websocket-protocol token over query access_token", %{
|
test "prefers sec-websocket-protocol token over query access_token", %{
|
||||||
token: token,
|
token: token,
|
||||||
user: user
|
user: user
|
||||||
|
|
@ -450,12 +516,12 @@ defmodule Pleroma.Integration.MastodonWebsocketTest do
|
||||||
assert {:ok, json} = Jason.decode(json["payload"])
|
assert {:ok, json} = Jason.decode(json["payload"])
|
||||||
|
|
||||||
view_json =
|
view_json =
|
||||||
Pleroma.Web.MastodonAPI.StatusView.render("show.json",
|
atom_key_to_string(
|
||||||
activity: activity,
|
StatusView.render("show.json",
|
||||||
for: reading_user
|
activity: activity,
|
||||||
|
for: reading_user
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|> Jason.encode!()
|
|
||||||
|> Jason.decode!()
|
|
||||||
|
|
||||||
assert json == view_json
|
assert json == view_json
|
||||||
end
|
end
|
||||||
|
|
@ -478,12 +544,12 @@ defmodule Pleroma.Integration.MastodonWebsocketTest do
|
||||||
activity = Pleroma.Activity.normalize(activity)
|
activity = Pleroma.Activity.normalize(activity)
|
||||||
|
|
||||||
view_json =
|
view_json =
|
||||||
Pleroma.Web.MastodonAPI.StatusView.render("show.json",
|
atom_key_to_string(
|
||||||
activity: activity,
|
StatusView.render("show.json",
|
||||||
for: reading_user
|
activity: activity,
|
||||||
|
for: reading_user
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|> Jason.encode!()
|
|
||||||
|> Jason.decode!()
|
|
||||||
|
|
||||||
assert {:ok, %{"event" => "status.update", "payload" => ^view_json}} = decode_json(raw_json)
|
assert {:ok, %{"event" => "status.update", "payload" => ^view_json}} = decode_json(raw_json)
|
||||||
end
|
end
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue