Fix WebSocket protocol token handshakes

This commit is contained in:
Lain Soykaf 2026-05-22 16:30:14 +04:00
commit 9dd02ecd50
No known key found for this signature in database
5 changed files with 200 additions and 28 deletions

View file

@ -9,8 +9,8 @@ defmodule Pleroma.Web.Endpoint do
alias Pleroma.Config
socket("/api/v1/streaming", Pleroma.Web.MastodonAPI.WebsocketHandler,
longpoll: false,
plug(Pleroma.Web.MastodonAPI.WebsocketPlug,
path: "/api/v1/streaming",
websocket: [
path: "/",
compress: false,
@ -169,8 +169,7 @@ defmodule Pleroma.Web.Endpoint do
else: "pleroma_key"
extra =
Config.get([__MODULE__, :extra_cookie_attrs])
|> Enum.join(";")
Enum.join(Config.get([__MODULE__, :extra_cookie_attrs]), ";")
# The session will be stored in the cookie and signed,
# this means its contents can be read but not tampered with.

View file

@ -67,9 +67,10 @@ defmodule Pleroma.Web.MastodonAPI.WebsocketHandler do
@impl Phoenix.Socket.Transport
def handle_in({text, [opcode: :text]}, state) do
with {:ok, %{} = event} <- Jason.decode(text) do
handle_client_event(event, state)
else
case Jason.decode(text) do
{:ok, %{} = event} ->
handle_client_event(event, state)
_ ->
Logger.error("#{__MODULE__} received non-JSON event: #{inspect(text)}")
{:ok, state}
@ -85,11 +86,11 @@ defmodule Pleroma.Web.MastodonAPI.WebsocketHandler do
def handle_info({:render_with_user, view, template, item, topic}, state) do
user = %User{} = User.get_cached_by_ap_id(state.user.ap_id)
unless Streamer.filtered_by_user?(user, item) do
if Streamer.filtered_by_user?(user, item) do
{:ok, state}
else
message = view.render(template, item, user, topic)
{:push, {:text, message}, %{state | user: user}}
else
{:ok, state}
end
end
@ -253,7 +254,7 @@ defmodule Pleroma.Web.MastodonAPI.WebsocketHandler do
defp find_sec_websocket_protocol(sec_headers) do
Enum.find_value(sec_headers, fn
{"sec-websocket-protocol", token} -> token
{"sec-websocket-protocol", protocols} -> protocols |> Plug.Conn.Utils.list() |> List.first()
_ -> nil
end)
end

View file

@ -0,0 +1,105 @@
# Pleroma: A lightweight social networking server
# Copyright © 2017-2022 Pleroma Authors <https://pleroma.social/>
# SPDX-License-Identifier: AGPL-3.0-only
defmodule Pleroma.Web.MastodonAPI.WebsocketPlug do
@moduledoc """
A Phoenix 1.8 compatible WebSocket transport for Mastodon streaming.
It mirrors Phoenix.Transports.WebSocket, but echoes a successfully authenticated
Mastodon-style Sec-WebSocket-Protocol token so browser clients accept the handshake.
"""
@behaviour Plug
import Plug.Conn
alias Phoenix.Socket.Transport
alias Pleroma.Web.Endpoint
alias Pleroma.Web.MastodonAPI.WebsocketHandler
@connect_info_opts [:check_csrf]
@impl Plug
def init(opts) do
path = String.split(Keyword.fetch!(opts, :path), "/", trim: true)
websocket = Keyword.fetch!(opts, :websocket)
config = Transport.load_config(websocket, Phoenix.Transports.WebSocket)
{path, config}
end
@impl Plug
def call(%{method: "GET", path_info: path} = conn, {path, opts}) do
conn
|> fetch_query_params()
|> Transport.code_reload(Endpoint, opts)
|> Transport.transport_log(opts[:transport_log])
|> Transport.check_origin(WebsocketHandler, Endpoint, opts)
|> connect(opts)
end
def call(%{path_info: path} = conn, {path, _opts}) do
conn
|> send_resp(400, "")
|> halt()
end
def call(conn, _opts), do: conn
defp connect(%{halted: true} = conn, _opts), do: conn
defp connect(%{params: params} = conn, opts) do
keys = Keyword.get(opts, :connect_info, [])
connect_info =
Transport.connect_info(conn, Endpoint, keys, Keyword.take(opts, @connect_info_opts))
config = %{
endpoint: Endpoint,
transport: :websocket,
options: opts,
params: params,
connect_info: connect_info
}
case WebsocketHandler.connect(config) do
{:ok, arg} ->
try do
conn
|> echo_sec_websocket_protocol()
|> WebSockAdapter.upgrade(WebsocketHandler, arg, opts)
|> halt()
rescue
e in WebSockAdapter.UpgradeError ->
conn
|> send_resp(400, e.message)
|> halt()
end
:error ->
conn
|> send_resp(403, "")
|> halt()
{:error, reason} ->
{m, f, args} = opts[:error_handler]
halt(apply(m, f, [conn, reason | args]))
end
end
defp echo_sec_websocket_protocol(conn) do
case get_req_header(conn, "sec-websocket-protocol") do
[protocols | _] ->
case Plug.Conn.Utils.list(protocols) do
[protocol | _] -> put_resp_header(conn, "sec-websocket-protocol", protocol)
nil -> conn
[] -> conn
end
[] ->
conn
end
end
end