refactor(ext/websocket): use concrete error type (#26226)

This commit is contained in:
Leo Kettmeir 2024-10-18 12:30:46 -07:00 committed by GitHub
parent 4b99cde504
commit d047cab14b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 172 additions and 97 deletions

1
Cargo.lock generated
View File

@ -2225,6 +2225,7 @@ dependencies = [
"once_cell",
"rustls-tokio-stream",
"serde",
"thiserror",
"tokio",
]

View File

@ -246,7 +246,11 @@ pub async fn op_http_upgrade_websocket_next(
// Stage 3: take the extracted raw network stream and upgrade it to a websocket, then return it
let (stream, bytes) = extract_network_stream(upgraded);
ws_create_server_stream(&mut state.borrow_mut(), stream, bytes)
Ok(ws_create_server_stream(
&mut state.borrow_mut(),
stream,
bytes,
))
}
#[op2(fast)]

View File

@ -1053,9 +1053,11 @@ async fn op_http_upgrade_websocket(
let (transport, bytes) =
extract_network_stream(hyper_v014::upgrade::on(request).await?);
let ws_rid =
ws_create_server_stream(&mut state.borrow_mut(), transport, bytes)?;
Ok(ws_rid)
Ok(ws_create_server_stream(
&mut state.borrow_mut(),
transport,
bytes,
))
}
// Needed so hyper can use non Send futures

View File

@ -28,4 +28,5 @@ hyper-util.workspace = true
once_cell.workspace = true
rustls-tokio-stream.workspace = true
serde.workspace = true
thiserror.workspace = true
tokio.workspace = true

View File

@ -1,10 +1,6 @@
// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license.
use crate::stream::WebSocketStream;
use bytes::Bytes;
use deno_core::anyhow::bail;
use deno_core::error::invalid_hostname;
use deno_core::error::type_error;
use deno_core::error::AnyError;
use deno_core::futures::TryFutureExt;
use deno_core::op2;
use deno_core::unsync::spawn;
@ -43,7 +39,6 @@ use serde::Serialize;
use std::borrow::Cow;
use std::cell::Cell;
use std::cell::RefCell;
use std::fmt;
use std::future::Future;
use std::num::NonZeroUsize;
use std::path::PathBuf;
@ -75,11 +70,33 @@ static USE_WRITEV: Lazy<bool> = Lazy::new(|| {
false
});
#[derive(Debug, thiserror::Error)]
pub enum WebsocketError {
#[error(transparent)]
Url(url::ParseError),
#[error(transparent)]
Permission(deno_core::error::AnyError),
#[error(transparent)]
Resource(deno_core::error::AnyError),
#[error(transparent)]
Uri(#[from] http::uri::InvalidUri),
#[error("{0}")]
Io(#[from] std::io::Error),
#[error(transparent)]
WebSocket(#[from] fastwebsockets::WebSocketError),
#[error("failed to connect to WebSocket: {0}")]
ConnectionFailed(#[from] HandshakeError),
#[error(transparent)]
Canceled(#[from] deno_core::Canceled),
}
#[derive(Clone)]
pub struct WsRootStoreProvider(Option<Arc<dyn RootCertStoreProvider>>);
impl WsRootStoreProvider {
pub fn get_or_try_init(&self) -> Result<Option<RootCertStore>, AnyError> {
pub fn get_or_try_init(
&self,
) -> Result<Option<RootCertStore>, deno_core::error::AnyError> {
Ok(match &self.0 {
Some(provider) => Some(provider.get_or_try_init()?.clone()),
None => None,
@ -95,7 +112,7 @@ pub trait WebSocketPermissions {
&mut self,
_url: &url::Url,
_api_name: &str,
) -> Result<(), AnyError>;
) -> Result<(), deno_core::error::AnyError>;
}
impl WebSocketPermissions for deno_permissions::PermissionsContainer {
@ -104,7 +121,7 @@ impl WebSocketPermissions for deno_permissions::PermissionsContainer {
&mut self,
url: &url::Url,
api_name: &str,
) -> Result<(), AnyError> {
) -> Result<(), deno_core::error::AnyError> {
deno_permissions::PermissionsContainer::check_net_url(self, url, api_name)
}
}
@ -137,13 +154,17 @@ pub fn op_ws_check_permission_and_cancel_handle<WP>(
#[string] api_name: String,
#[string] url: String,
cancel_handle: bool,
) -> Result<Option<ResourceId>, AnyError>
) -> Result<Option<ResourceId>, WebsocketError>
where
WP: WebSocketPermissions + 'static,
{
state
.borrow_mut::<WP>()
.check_net_url(&url::Url::parse(&url)?, &api_name)?;
.check_net_url(
&url::Url::parse(&url).map_err(WebsocketError::Url)?,
&api_name,
)
.map_err(WebsocketError::Permission)?;
if cancel_handle {
let rid = state
@ -163,16 +184,46 @@ pub struct CreateResponse {
extensions: String,
}
#[derive(Debug, thiserror::Error)]
pub enum HandshakeError {
#[error("Missing path in url")]
MissingPath,
#[error("Invalid status code {0}")]
InvalidStatusCode(StatusCode),
#[error(transparent)]
Http(#[from] http::Error),
#[error(transparent)]
WebSocket(#[from] fastwebsockets::WebSocketError),
#[error("Didn't receive h2 alpn, aborting connection")]
NoH2Alpn,
#[error(transparent)]
Rustls(#[from] deno_tls::rustls::Error),
#[error(transparent)]
Io(#[from] std::io::Error),
#[error(transparent)]
H2(#[from] h2::Error),
#[error("Invalid hostname: '{0}'")]
InvalidHostname(String),
#[error(transparent)]
RootStoreError(deno_core::error::AnyError),
#[error(transparent)]
Tls(deno_tls::TlsError),
#[error(transparent)]
HeaderName(#[from] http::header::InvalidHeaderName),
#[error(transparent)]
HeaderValue(#[from] http::header::InvalidHeaderValue),
}
async fn handshake_websocket(
state: &Rc<RefCell<OpState>>,
uri: &Uri,
protocols: &str,
headers: Option<Vec<(ByteString, ByteString)>>,
) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), AnyError> {
) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), HandshakeError> {
let mut request = Request::builder().method(Method::GET).uri(
uri
.path_and_query()
.ok_or(type_error("Missing path in url".to_string()))?
.ok_or(HandshakeError::MissingPath)?
.as_str(),
);
@ -194,7 +245,9 @@ async fn handshake_websocket(
request =
populate_common_request_headers(request, &user_agent, protocols, &headers)?;
let request = request.body(http_body_util::Empty::new())?;
let request = request
.body(http_body_util::Empty::new())
.map_err(HandshakeError::Http)?;
let domain = &uri.host().unwrap().to_string();
let port = &uri.port_u16().unwrap_or(match uri.scheme_str() {
Some("wss") => 443,
@ -231,7 +284,7 @@ async fn handshake_websocket(
async fn handshake_http1_ws(
request: Request<http_body_util::Empty<Bytes>>,
addr: &String,
) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), AnyError> {
) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), HandshakeError> {
let tcp_socket = TcpStream::connect(addr).await?;
handshake_connection(request, tcp_socket).await
}
@ -241,11 +294,11 @@ async fn handshake_http1_wss(
request: Request<http_body_util::Empty<Bytes>>,
domain: &str,
addr: &str,
) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), AnyError> {
) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), HandshakeError> {
let tcp_socket = TcpStream::connect(addr).await?;
let tls_config = create_ws_client_config(state, SocketUse::Http1Only)?;
let dnsname = ServerName::try_from(domain.to_string())
.map_err(|_| invalid_hostname(domain))?;
.map_err(|_| HandshakeError::InvalidHostname(domain.to_string()))?;
let mut tls_connector = TlsStream::new_client_side(
tcp_socket,
ClientConnection::new(tls_config.into(), dnsname)?,
@ -266,11 +319,11 @@ async fn handshake_http2_wss(
domain: &str,
headers: &Option<Vec<(ByteString, ByteString)>>,
addr: &str,
) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), AnyError> {
) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), HandshakeError> {
let tcp_socket = TcpStream::connect(addr).await?;
let tls_config = create_ws_client_config(state, SocketUse::Http2Only)?;
let dnsname = ServerName::try_from(domain.to_string())
.map_err(|_| invalid_hostname(domain))?;
.map_err(|_| HandshakeError::InvalidHostname(domain.to_string()))?;
// We need to better expose the underlying errors here
let mut tls_connector = TlsStream::new_client_side(
tcp_socket,
@ -279,7 +332,7 @@ async fn handshake_http2_wss(
);
let handshake = tls_connector.handshake().await?;
if handshake.alpn.is_none() {
bail!("Didn't receive h2 alpn, aborting connection");
return Err(HandshakeError::NoH2Alpn);
}
let h2 = h2::client::Builder::new();
let (mut send, conn) = h2.handshake::<_, Bytes>(tls_connector).await?;
@ -298,7 +351,7 @@ async fn handshake_http2_wss(
let (resp, send) = send.send_request(request.body(())?, false)?;
let resp = resp.await?;
if resp.status() != StatusCode::OK {
bail!("Invalid status code: {}", resp.status());
return Err(HandshakeError::InvalidStatusCode(resp.status()));
}
let (http::response::Parts { headers, .. }, recv) = resp.into_parts();
let mut stream = WebSocket::after_handshake(
@ -317,7 +370,7 @@ async fn handshake_connection<
>(
request: Request<http_body_util::Empty<Bytes>>,
socket: S,
) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), AnyError> {
) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), HandshakeError> {
let (upgraded, response) =
fastwebsockets::handshake::client(&LocalExecutor, request, socket).await?;
@ -332,7 +385,7 @@ async fn handshake_connection<
pub fn create_ws_client_config(
state: &Rc<RefCell<OpState>>,
socket_use: SocketUse,
) -> Result<ClientConfig, AnyError> {
) -> Result<ClientConfig, HandshakeError> {
let unsafely_ignore_certificate_errors: Option<Vec<String>> = state
.borrow()
.try_borrow::<UnsafelyIgnoreCertificateErrors>()
@ -340,7 +393,8 @@ pub fn create_ws_client_config(
let root_cert_store = state
.borrow()
.borrow::<WsRootStoreProvider>()
.get_or_try_init()?;
.get_or_try_init()
.map_err(HandshakeError::RootStoreError)?;
create_client_config(
root_cert_store,
@ -349,7 +403,7 @@ pub fn create_ws_client_config(
TlsKeys::Null,
socket_use,
)
.map_err(|e| e.into())
.map_err(HandshakeError::Tls)
}
/// Headers common to both http/1.1 and h2 requests.
@ -358,7 +412,7 @@ fn populate_common_request_headers(
user_agent: &str,
protocols: &str,
headers: &Option<Vec<(ByteString, ByteString)>>,
) -> Result<http::request::Builder, AnyError> {
) -> Result<http::request::Builder, HandshakeError> {
request = request
.header("User-Agent", user_agent)
.header("Sec-WebSocket-Version", "13");
@ -369,10 +423,8 @@ fn populate_common_request_headers(
if let Some(headers) = headers {
for (key, value) in headers {
let name = HeaderName::from_bytes(key)
.map_err(|err| type_error(err.to_string()))?;
let v = HeaderValue::from_bytes(value)
.map_err(|err| type_error(err.to_string()))?;
let name = HeaderName::from_bytes(key)?;
let v = HeaderValue::from_bytes(value)?;
let is_disallowed_header = matches!(
name,
@ -402,14 +454,17 @@ pub async fn op_ws_create<WP>(
#[string] protocols: String,
#[smi] cancel_handle: Option<ResourceId>,
#[serde] headers: Option<Vec<(ByteString, ByteString)>>,
) -> Result<CreateResponse, AnyError>
) -> Result<CreateResponse, WebsocketError>
where
WP: WebSocketPermissions + 'static,
{
{
let mut s = state.borrow_mut();
s.borrow_mut::<WP>()
.check_net_url(&url::Url::parse(&url)?, &api_name)
.check_net_url(
&url::Url::parse(&url).map_err(WebsocketError::Url)?,
&api_name,
)
.expect(
"Permission check should have been done in op_ws_check_permission",
);
@ -419,7 +474,8 @@ where
let r = state
.borrow_mut()
.resource_table
.get::<WsCancelResource>(cancel_rid)?;
.get::<WsCancelResource>(cancel_rid)
.map_err(WebsocketError::Resource)?;
Some(r.0.clone())
} else {
None
@ -428,15 +484,11 @@ where
let uri: Uri = url.parse()?;
let handshake = handshake_websocket(&state, &uri, &protocols, headers)
.map_err(|err| {
AnyError::from(DomExceptionNetworkError::new(&format!(
"failed to connect to WebSocket: {err}"
)))
});
.map_err(WebsocketError::ConnectionFailed);
let (stream, response) = match cancel_resource {
Some(rc) => handshake.try_or_cancel(rc).await,
None => handshake.await,
}?;
Some(rc) => handshake.try_or_cancel(rc).await?,
None => handshake.await?,
};
if let Some(cancel_rid) = cancel_handle {
if let Ok(res) = state.borrow_mut().resource_table.take_any(cancel_rid) {
@ -521,14 +573,12 @@ impl ServerWebSocket {
self: &Rc<Self>,
lock: AsyncMutFuture<WebSocketWrite<WriteHalf<WebSocketStream>>>,
frame: Frame<'_>,
) -> Result<(), AnyError> {
) -> Result<(), WebsocketError> {
let mut ws = lock.await;
if ws.is_closed() {
return Ok(());
}
ws.write_frame(frame)
.await
.map_err(|err| type_error(err.to_string()))?;
ws.write_frame(frame).await?;
Ok(())
}
}
@ -543,7 +593,7 @@ pub fn ws_create_server_stream(
state: &mut OpState,
transport: NetworkStream,
read_buf: Bytes,
) -> Result<ResourceId, AnyError> {
) -> ResourceId {
let mut ws = WebSocket::after_handshake(
WebSocketStream::new(
stream::WsStreamKind::Network(transport),
@ -555,8 +605,7 @@ pub fn ws_create_server_stream(
ws.set_auto_close(true);
ws.set_auto_pong(true);
let rid = state.resource_table.add(ServerWebSocket::new(ws));
Ok(rid)
state.resource_table.add(ServerWebSocket::new(ws))
}
fn send_binary(state: &mut OpState, rid: ResourceId, data: &[u8]) {
@ -626,11 +675,12 @@ pub async fn op_ws_send_binary_async(
state: Rc<RefCell<OpState>>,
#[smi] rid: ResourceId,
#[buffer] data: JsBuffer,
) -> Result<(), AnyError> {
) -> Result<(), WebsocketError> {
let resource = state
.borrow_mut()
.resource_table
.get::<ServerWebSocket>(rid)?;
.get::<ServerWebSocket>(rid)
.map_err(WebsocketError::Resource)?;
let data = data.to_vec();
let lock = resource.reserve_lock();
resource
@ -644,11 +694,12 @@ pub async fn op_ws_send_text_async(
state: Rc<RefCell<OpState>>,
#[smi] rid: ResourceId,
#[string] data: String,
) -> Result<(), AnyError> {
) -> Result<(), WebsocketError> {
let resource = state
.borrow_mut()
.resource_table
.get::<ServerWebSocket>(rid)?;
.get::<ServerWebSocket>(rid)
.map_err(WebsocketError::Resource)?;
let lock = resource.reserve_lock();
resource
.write_frame(
@ -678,11 +729,12 @@ pub fn op_ws_get_buffered_amount(
pub async fn op_ws_send_ping(
state: Rc<RefCell<OpState>>,
#[smi] rid: ResourceId,
) -> Result<(), AnyError> {
) -> Result<(), WebsocketError> {
let resource = state
.borrow_mut()
.resource_table
.get::<ServerWebSocket>(rid)?;
.get::<ServerWebSocket>(rid)
.map_err(WebsocketError::Resource)?;
let lock = resource.reserve_lock();
resource
.write_frame(
@ -698,7 +750,7 @@ pub async fn op_ws_close(
#[smi] rid: ResourceId,
#[smi] code: Option<u16>,
#[string] reason: Option<String>,
) -> Result<(), AnyError> {
) -> Result<(), WebsocketError> {
let Ok(resource) = state
.borrow_mut()
.resource_table
@ -713,8 +765,7 @@ pub async fn op_ws_close(
resource.closed.set(true);
let lock = resource.reserve_lock();
resource.write_frame(lock, frame).await?;
Ok(())
resource.write_frame(lock, frame).await
}
#[op2]
@ -868,32 +919,6 @@ pub fn get_declaration() -> PathBuf {
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("lib.deno_websocket.d.ts")
}
#[derive(Debug)]
pub struct DomExceptionNetworkError {
pub msg: String,
}
impl DomExceptionNetworkError {
pub fn new(msg: &str) -> Self {
DomExceptionNetworkError {
msg: msg.to_string(),
}
}
}
impl fmt::Display for DomExceptionNetworkError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.pad(&self.msg)
}
}
impl std::error::Error for DomExceptionNetworkError {}
pub fn get_network_error_class_name(e: &AnyError) -> Option<&'static str> {
e.downcast_ref::<DomExceptionNetworkError>()
.map(|_| "DOMExceptionNetworkError")
}
// Needed so hyper can use non Send futures
#[derive(Clone)]
struct LocalExecutor;

View File

@ -36,6 +36,8 @@ use deno_web::CompressionError;
use deno_web::MessagePortError;
use deno_web::StreamResourceError;
use deno_web::WebError;
use deno_websocket::HandshakeError;
use deno_websocket::WebsocketError;
use deno_webstorage::WebStorageError;
use std::env;
use std::error::Error;
@ -368,6 +370,43 @@ fn get_broadcast_channel_error(error: &BroadcastChannelError) -> &'static str {
}
}
fn get_websocket_error(error: &WebsocketError) -> &'static str {
match error {
WebsocketError::Permission(e) | WebsocketError::Resource(e) => {
get_error_class_name(e).unwrap_or("Error")
}
WebsocketError::Url(e) => get_url_parse_error_class(e),
WebsocketError::Io(e) => get_io_error_class(e),
WebsocketError::WebSocket(_) => "TypeError",
WebsocketError::ConnectionFailed(_) => "DOMExceptionNetworkError",
WebsocketError::Uri(_) => "Error",
WebsocketError::Canceled(e) => {
let io_err: io::Error = e.to_owned().into();
get_io_error_class(&io_err)
}
}
}
fn get_websocket_handshake_error(error: &HandshakeError) -> &'static str {
match error {
HandshakeError::RootStoreError(e) => {
get_error_class_name(e).unwrap_or("Error")
}
HandshakeError::Tls(e) => get_tls_error_class(e),
HandshakeError::MissingPath => "TypeError",
HandshakeError::Http(_) => "Error",
HandshakeError::InvalidHostname(_) => "TypeError",
HandshakeError::Io(e) => get_io_error_class(e),
HandshakeError::Rustls(_) => "Error",
HandshakeError::H2(_) => "Error",
HandshakeError::NoH2Alpn => "Error",
HandshakeError::InvalidStatusCode(_) => "Error",
HandshakeError::WebSocket(_) => "TypeError",
HandshakeError::HeaderName(_) => "TypeError",
HandshakeError::HeaderValue(_) => "TypeError",
}
}
fn get_fs_error(error: &FsOpsError) -> &'static str {
match error {
FsOpsError::Io(e) => get_io_error_class(e),
@ -482,7 +521,6 @@ fn get_net_map_error(error: &deno_net::io::MapError) -> &'static str {
pub fn get_error_class_name(e: &AnyError) -> Option<&'static str> {
deno_core::error::get_custom_error_class(e)
.or_else(|| deno_webgpu::error::get_error_class_name(e))
.or_else(|| deno_websocket::get_network_error_class_name(e))
.or_else(|| e.downcast_ref::<NApiError>().map(get_napi_error_class))
.or_else(|| e.downcast_ref::<WebError>().map(get_web_error_class))
.or_else(|| {
@ -518,6 +556,11 @@ pub fn get_error_class_name(e: &AnyError) -> Option<&'static str> {
.or_else(|| e.downcast_ref::<CronError>().map(get_cron_error_class))
.or_else(|| e.downcast_ref::<CanvasError>().map(get_canvas_error))
.or_else(|| e.downcast_ref::<CacheError>().map(get_cache_error))
.or_else(|| e.downcast_ref::<WebsocketError>().map(get_websocket_error))
.or_else(|| {
e.downcast_ref::<HandshakeError>()
.map(get_websocket_handshake_error)
})
.or_else(|| e.downcast_ref::<KvError>().map(get_kv_error))
.or_else(|| e.downcast_ref::<NetError>().map(get_net_error))
.or_else(|| {

View File

@ -4,6 +4,7 @@ use std::sync::Arc;
use crate::web_worker::WebWorkerInternalHandle;
use crate::web_worker::WebWorkerType;
use deno_core::error::custom_error;
use deno_core::error::type_error;
use deno_core::error::AnyError;
use deno_core::futures::StreamExt;
@ -12,7 +13,6 @@ use deno_core::url::Url;
use deno_core::OpState;
use deno_fetch::data_url::DataUrl;
use deno_web::BlobStore;
use deno_websocket::DomExceptionNetworkError;
use http_body_util::BodyExt;
use hyper::body::Bytes;
use serde::Deserialize;
@ -151,17 +151,16 @@ pub fn op_worker_sync_fetch(
match mime_type.as_deref() {
Some("application/javascript" | "text/javascript") => {}
Some(mime_type) => {
return Err(
DomExceptionNetworkError {
msg: format!("Invalid MIME type {mime_type:?}."),
}
.into(),
)
return Err(custom_error(
"DOMExceptionNetworkError",
format!("Invalid MIME type {mime_type:?}."),
))
}
None => {
return Err(
DomExceptionNetworkError::new("Missing MIME type.").into(),
)
return Err(custom_error(
"DOMExceptionNetworkError",
"Missing MIME type.",
))
}
}
}