]> Untitled Git - bdk-cli/commitdiff
Retry Payjoin requests with other relays
authorMshehu5 <musheu@gmail.com>
Sun, 28 Jun 2026 16:01:25 +0000 (17:01 +0100)
committerMshehu5 <musheu@gmail.com>
Thu, 2 Jul 2026 11:30:31 +0000 (12:30 +0100)
Payjoin currently keeps using the relay selected during OHTTP key
bootstrapping. If that relay goes offline, the session fails even when
other relay URLs were supplied.

Select from the available relays for each OHTTP request and remember
failures for the rest of the command. Build a fresh request for every
attempt so failover does not reuse linkable OHTTP ciphertext.

src/payjoin/mod.rs
src/payjoin/ohttp.rs

index e811213e5814625d02030dc4470657f8101a9b8f..1658c7560bca6f3ae0a4702153276c75e72aa436 100644 (file)
@@ -23,13 +23,10 @@ use payjoin::send::v2::{
 };
 use payjoin::{HpkePublicKey, ImplementationError, UriExt};
 use serde_json::{json, to_string_pretty};
-use std::{
-    path::PathBuf,
-    sync::{Arc, Mutex},
-};
+use std::{path::PathBuf, sync::Arc};
 
 use crate::payjoin::db::{ReceiverPersister, SenderPersister, open_payjoin_db};
-use crate::payjoin::ohttp::{RelayManager, fetch_ohttp_keys};
+use crate::payjoin::ohttp::RelayManager;
 
 pub mod db;
 pub mod ohttp;
@@ -41,7 +38,7 @@ pub mod ohttp;
 /// [`PayjoinManager::proceed_receiver_session`].
 pub(crate) struct PayjoinManager<'a> {
     wallet: &'a mut Wallet,
-    relay_manager: Arc<Mutex<RelayManager>>,
+    relay_manager: RelayManager,
     db: Arc<crate::payjoin::db::Database>,
 }
 
@@ -110,7 +107,7 @@ impl<'a> PayjoinManager<'a> {
         wallet_name: &str,
     ) -> Result<Self, Error> {
         let db = open_payjoin_db(datadir, wallet_name)?;
-        let relay_manager = Arc::new(Mutex::new(RelayManager::new()));
+        let relay_manager = RelayManager::new();
 
         Ok(Self {
             wallet,
@@ -137,14 +134,8 @@ impl<'a> PayjoinManager<'a> {
             .collect::<Result<_, _>>()
             .map_err(|e| Error::Generic(format!("Failed to parse one or more OHTTP URLs: {e}")))?;
 
-        if ohttp_relays.is_empty() {
-            return Err(Error::Generic(
-                "At least one valid OHTTP relay must be provided.".into(),
-            ));
-        }
-
-        let ohttp_keys =
-            fetch_ohttp_keys(ohttp_relays, &directory, self.relay_manager.clone()).await?;
+        self.relay_manager.configure(ohttp_relays)?;
+        let ohttp_keys = self.relay_manager.fetch_ohttp_keys(&directory).await?;
 
         let persister = crate::payjoin::db::ReceiverPersister::new(self.db.clone())?;
 
@@ -152,20 +143,17 @@ impl<'a> PayjoinManager<'a> {
             .map(FeeRate::from_sat_per_kwu)
             .unwrap_or(FeeRate::BROADCAST_MIN);
 
-        let receiver = payjoin::receive::v2::ReceiverBuilder::new(
-            address.address,
-            directory,
-            ohttp_keys.ohttp_keys,
-        )?
-        .with_amount(payjoin::bitcoin::Amount::from_sat(amount))
-        .with_max_fee_rate(checked_max_fee_rate)
-        .build()
-        .save(&persister)
-        .map_err(|e| {
-            Error::Generic(format!(
-                "Failed to persister the receiver after initialization: {e}"
-            ))
-        })?;
+        let receiver =
+            payjoin::receive::v2::ReceiverBuilder::new(address.address, directory, ohttp_keys)?
+                .with_amount(payjoin::bitcoin::Amount::from_sat(amount))
+                .with_max_fee_rate(checked_max_fee_rate)
+                .build()
+                .save(&persister)
+                .map_err(|e| {
+                    Error::Generic(format!(
+                        "Failed to persister the receiver after initialization: {e}"
+                    ))
+                })?;
 
         let pj_uri = receiver.pj_uri();
         println!("Request Payjoin by sharing this Payjoin Uri:");
@@ -174,7 +162,6 @@ impl<'a> PayjoinManager<'a> {
         self.proceed_receiver_session(
             ReceiveSession::Initialized(receiver.clone()),
             &persister,
-            ohttp_keys.relay_url.to_string(),
             checked_max_fee_rate,
             blockchain_client,
         )
@@ -244,11 +231,7 @@ impl<'a> PayjoinManager<'a> {
                         Error::Generic(format!("Failed to parse one or more OHTTP URLs: {e}"))
                     })?;
 
-                if ohttp_relays.is_empty() {
-                    return Err(Error::Generic(
-                        "At least one valid OHTTP relay must be provided.".into(),
-                    ));
-                }
+                self.relay_manager.configure(ohttp_relays)?;
                 // Check for existing session with the same receiver pubkey
                 let receiver_pubkey = v2_param.receiver_pubkey();
                 let existing_session =
@@ -293,13 +276,8 @@ impl<'a> PayjoinManager<'a> {
                     (SendSession::WithReplyKey(sender), persister)
                 };
 
-                self.proceed_sender_session(
-                    sender_state,
-                    &persister,
-                    ohttp_relays,
-                    blockchain_client,
-                )
-                .await?
+                self.proceed_sender_session(sender_state, &persister, blockchain_client)
+                    .await?
             }
             _ => {
                 unimplemented!("Payjoin version not recognized.");
@@ -313,20 +291,13 @@ impl<'a> PayjoinManager<'a> {
         &mut self,
         session: ReceiveSession,
         persister: &impl SessionPersister<SessionEvent = ReceiverSessionEvent>,
-        relay: impl payjoin::IntoUrl,
         max_fee_rate: FeeRate,
         blockchain_client: &BlockchainClient,
     ) -> Result<(), Error> {
         match session {
             ReceiveSession::Initialized(proposal) => {
-                self.read_from_directory(
-                    proposal,
-                    persister,
-                    relay,
-                    max_fee_rate,
-                    blockchain_client,
-                )
-                .await
+                self.read_from_directory(proposal, persister, max_fee_rate, blockchain_client)
+                    .await
             }
             ReceiveSession::UncheckedOriginalPayload(proposal) => {
                 self.check_proposal(proposal, persister, max_fee_rate, blockchain_client)
@@ -382,14 +353,14 @@ impl<'a> PayjoinManager<'a> {
         &mut self,
         receiver: Receiver<Initialized>,
         persister: &impl SessionPersister<SessionEvent = ReceiverSessionEvent>,
-        relay: impl payjoin::IntoUrl,
         max_fee_rate: FeeRate,
         blockchain_client: &BlockchainClient,
     ) -> Result<(), Error> {
         let mut current_receiver_typestate = receiver;
         let next_receiver_typestate = loop {
-            let (req, context) = current_receiver_typestate.create_poll_request(relay.as_str())?;
-            let response = self.send_payjoin_post_request(req).await?;
+            let (response, context) = self
+                .post_via_relay(|relay| current_receiver_typestate.create_poll_request(relay))
+                .await?;
             let state_transition = current_receiver_typestate
                 .process_response(response.bytes().await?.to_vec().as_slice(), context)
                 .save(persister);
@@ -630,17 +601,9 @@ impl<'a> PayjoinManager<'a> {
         persister: &impl SessionPersister<SessionEvent = ReceiverSessionEvent>,
         blockchain_client: &BlockchainClient,
     ) -> Result<(), Error> {
-        let (req, ctx) = receiver
-            .create_post_request(
-                self.unwrap_relay_or_else_fetch(vec![], None::<&str>)
-                    .await?
-                    .as_str(),
-            )
-            .map_err(|e| {
-                Error::Generic(format!("Error occurred when creating a post request for sending final Payjoin proposal: {e}"))
-            })?;
-
-        let res = self.send_payjoin_post_request(req).await?;
+        let (res, ctx) = self
+            .post_via_relay(|relay| receiver.create_post_request(relay))
+            .await?;
         let payjoin_psbt = receiver.psbt().clone();
         let next_receiver_typestate = receiver
             .process_response(&res.bytes().await?, ctx)
@@ -731,28 +694,9 @@ impl<'a> PayjoinManager<'a> {
         receiver: Receiver<HasReplyableError>,
         persister: &impl SessionPersister<SessionEvent = ReceiverSessionEvent>,
     ) -> Result<(), Error> {
-        let (err_req, err_ctx) = receiver
-            .create_error_request(
-                self.unwrap_relay_or_else_fetch(vec![], None::<&str>)
-                    .await?
-                    .as_str(),
-            )
-            .map_err(|e| {
-                Error::Generic(format!(
-                    "Error occurred when creating a receiver error request: {}",
-                    e
-                ))
-            })?;
-
-        let err_response = match self.send_payjoin_post_request(err_req).await {
-            Ok(response) => response,
-            Err(e) => {
-                return Err(Error::Generic(format!(
-                    "Failed to post error request: {}",
-                    e
-                )));
-            }
-        };
+        let (err_response, err_ctx) = self
+            .post_via_relay(|relay| receiver.create_error_request(relay))
+            .await?;
 
         let err_bytes = match err_response.bytes().await {
             Ok(bytes) => bytes,
@@ -781,33 +725,16 @@ impl<'a> PayjoinManager<'a> {
         &self,
         session: SendSession,
         persister: &impl SessionPersister<SessionEvent = SenderSessionEvent>,
-        ohttp_relays: Vec<url::Url>,
         blockchain_client: &BlockchainClient,
     ) -> Result<Txid, Error> {
         match session {
             SendSession::WithReplyKey(context) => {
-                let relay = self
-                    .unwrap_relay_or_else_fetch(ohttp_relays, Some(context.endpoint()))
-                    .await?;
-                self.post_original_proposal(
-                    context,
-                    persister,
-                    blockchain_client,
-                    relay.to_string(),
-                )
-                .await
+                self.post_original_proposal(context, persister, blockchain_client)
+                    .await
             }
             SendSession::PollingForProposal(context) => {
-                let relay = self
-                    .unwrap_relay_or_else_fetch(ohttp_relays, Some(context.endpoint()))
-                    .await?;
-                self.get_proposed_payjoin_proposal(
-                    context,
-                    persister,
-                    blockchain_client,
-                    relay.to_string(),
-                )
-                .await
+                self.get_proposed_payjoin_proposal(context, persister, blockchain_client)
+                    .await
             }
             SendSession::Closed(SenderSessionOutcome::Success(psbt)) => {
                 self.process_payjoin_proposal(psbt, blockchain_client).await
@@ -816,44 +743,19 @@ impl<'a> PayjoinManager<'a> {
         }
     }
 
-    async fn unwrap_relay_or_else_fetch(
-        &self,
-        ohttp_relays: Vec<url::Url>,
-        directory: Option<impl payjoin::IntoUrl>,
-    ) -> Result<url::Url, Error> {
-        let selected_relay = self
-            .relay_manager
-            .lock()
-            .expect("Lock should not be poisoned")
-            .get_selected_relay();
-        match selected_relay {
-            Some(relay) => Ok(relay),
-            None => {
-                let directory = directory.ok_or_else(|| {
-                    Error::Generic("No directory URL provided and no relay selected".to_string())
-                })?;
-                Ok(
-                    fetch_ohttp_keys(ohttp_relays, directory, self.relay_manager.clone())
-                        .await?
-                        .relay_url,
-                )
-            }
-        }
-    }
-
     async fn post_original_proposal(
         &self,
         sender: Sender<WithReplyKey>,
         persister: &impl SessionPersister<SessionEvent = SenderSessionEvent>,
         blockchain_client: &BlockchainClient,
-        relay: impl payjoin::IntoUrl,
     ) -> Result<Txid, Error> {
-        let (req, ctx) = sender.create_v2_post_request(relay.as_str())?;
-        let response = self.send_payjoin_post_request(req).await?;
+        let (response, ctx) = self
+            .post_via_relay(|relay| sender.create_v2_post_request(relay))
+            .await?;
         let sender = sender
             .process_response(&response.bytes().await?, ctx)
             .save(persister)?;
-        self.get_proposed_payjoin_proposal(sender, persister, blockchain_client, relay)
+        self.get_proposed_payjoin_proposal(sender, persister, blockchain_client)
             .await
     }
 
@@ -862,12 +764,12 @@ impl<'a> PayjoinManager<'a> {
         sender: Sender<PollingForProposal>,
         persister: &impl SessionPersister<SessionEvent = SenderSessionEvent>,
         blockchain_client: &BlockchainClient,
-        relay: impl payjoin::IntoUrl,
     ) -> Result<Txid, Error> {
         let mut sender = sender.clone();
         loop {
-            let (req, ctx) = sender.create_poll_request(relay.as_str())?;
-            let response = self.send_payjoin_post_request(req).await?;
+            let (response, ctx) = self
+                .post_via_relay(|relay| sender.create_poll_request(relay))
+                .await?;
             let processed_response = sender
                 .process_response(&response.bytes().await?, ctx)
                 .save(persister);
@@ -914,13 +816,36 @@ impl<'a> PayjoinManager<'a> {
             .header("Content-Type", req.content_type)
             .body(req.body)
             .send()
-            .await
+            .await?
+            .error_for_status()
+    }
+
+    async fn post_via_relay<F, T, E>(&self, mut build: F) -> Result<(reqwest::Response, T), Error>
+    where
+        F: FnMut(&str) -> Result<(payjoin::Request, T), E>,
+        E: std::fmt::Display,
+    {
+        loop {
+            let relay = self.relay_manager.choose_relay()?;
+            // Build a fresh request for each attempt. Reusing an OHTTP
+            // ciphertext would let relays correlate retransmissions.
+            let (req, context) = build(relay.as_str())
+                .map_err(|e| Error::Generic(format!("Failed to create OHTTP request: {e}")))?;
+
+            match self.send_payjoin_post_request(req).await {
+                Ok(response) => return Ok((response, context)),
+                Err(e) => {
+                    tracing::debug!("Request to OHTTP relay {relay} failed: {e:?}");
+                    self.relay_manager.add_failed_relay(relay);
+                }
+            }
+        }
     }
 
     /// Resume pending payjoin sessions from the database
     pub async fn resume_payjoins(
         &mut self,
-        directory: String,
+        _directory: String,
         ohttp_relays: Vec<String>,
         session_id: Option<i64>,
         blockchain_client: &BlockchainClient,
@@ -951,10 +876,7 @@ impl<'a> PayjoinManager<'a> {
             .map(|s| url::Url::parse(&s))
             .collect::<Result<_, _>>()
             .map_err(|e| Error::Generic(format!("Failed to parse OHTTP URLs: {e}")))?;
-
-        let relay = self
-            .unwrap_relay_or_else_fetch(ohttp_relays, Some(&directory))
-            .await?;
+        self.relay_manager.configure(ohttp_relays)?;
 
         let max_fee_rate = FeeRate::BROADCAST_MIN;
         let total_sessions = recv_session_ids.len() + send_session_ids.len();
@@ -975,7 +897,6 @@ impl<'a> PayjoinManager<'a> {
                         self.proceed_receiver_session(
                             receiver_state,
                             &persister,
-                            relay.as_str(),
                             max_fee_rate,
                             blockchain_client,
                         ),
@@ -1010,12 +931,7 @@ impl<'a> PayjoinManager<'a> {
                     println!("Resuming sender session {}", session_id);
                     match tokio::time::timeout(
                         std::time::Duration::from_secs(30),
-                        self.proceed_sender_session(
-                            sender_state,
-                            &persister,
-                            vec![relay.clone()],
-                            blockchain_client,
-                        ),
+                        self.proceed_sender_session(sender_state, &persister, blockchain_client),
                     )
                     .await
                     {
index 1cc09353023323e4e5d4a20497314a6ab0003657..bc264667d6df8eca4aa69627726b8baf69f6849f 100644 (file)
@@ -3,108 +3,125 @@ use std::sync::{Arc, Mutex};
 
 #[derive(Debug, Clone)]
 pub(crate) struct RelayManager {
-    selected_relay: Option<url::Url>,
-    failed_relays: Vec<url::Url>,
+    relays: Vec<url::Url>,
+    failed_relays: Arc<Mutex<Vec<url::Url>>>,
 }
 
 impl RelayManager {
-    pub fn new() -> Self {
-        RelayManager {
-            selected_relay: None,
-            failed_relays: Vec::new(),
+    pub(crate) fn new() -> Self {
+        Self {
+            relays: Vec::new(),
+            failed_relays: Arc::new(Mutex::new(Vec::new())),
         }
     }
 
-    pub fn set_selected_relay(&mut self, relay: url::Url) {
-        self.selected_relay = Some(relay);
-    }
+    pub(crate) fn configure(&mut self, relays: Vec<url::Url>) -> Result<(), Error> {
+        if relays.is_empty() {
+            return Err(Error::Generic(
+                "At least one valid OHTTP relay must be provided.".into(),
+            ));
+        }
 
-    pub fn get_selected_relay(&self) -> Option<url::Url> {
-        self.selected_relay.clone()
+        self.relays = relays;
+        self.failed_relays
+            .lock()
+            .expect("Lock should not be poisoned")
+            .clear();
+        Ok(())
     }
 
-    pub fn add_failed_relay(&mut self, relay: url::Url) {
-        self.failed_relays.push(relay);
+    pub(crate) fn add_failed_relay(&self, relay: url::Url) {
+        let mut failed_relays = self
+            .failed_relays
+            .lock()
+            .expect("Lock should not be poisoned");
+        if !failed_relays.contains(&relay) {
+            failed_relays.push(relay);
+        }
     }
 
-    pub fn get_failed_relays(&self) -> Vec<url::Url> {
-        self.failed_relays.clone()
-    }
-}
+    pub(crate) fn choose_relay(&self) -> Result<url::Url, Error> {
+        use payjoin::bitcoin::secp256k1::rand::prelude::SliceRandom;
 
-pub(crate) struct ValidatedOhttpKeys {
-    pub(crate) ohttp_keys: payjoin::OhttpKeys,
-    pub(crate) relay_url: url::Url,
-}
-
-pub(crate) async fn fetch_ohttp_keys(
-    relays: Vec<url::Url>,
-    payjoin_directory: impl payjoin::IntoUrl,
-    relay_manager: Arc<Mutex<RelayManager>>,
-) -> Result<ValidatedOhttpKeys, Error> {
-    use payjoin::bitcoin::secp256k1::rand::prelude::SliceRandom;
-
-    loop {
-        let failed_relays = relay_manager
+        let failed_relays = self
+            .failed_relays
             .lock()
-            .expect("Lock should not be poisoned")
-            .get_failed_relays();
-
-        let remaining_relays: Vec<_> = relays
+            .expect("Lock should not be poisoned");
+        let remaining_relays: Vec<_> = self
+            .relays
             .iter()
-            .filter(|r| !failed_relays.contains(r))
+            .filter(|relay| !failed_relays.contains(relay))
             .cloned()
             .collect();
 
-        if remaining_relays.is_empty() {
-            return Err(Error::Generic(
-                "No valid OHTTP relays available".to_string(),
-            ));
-        }
+        remaining_relays
+            .choose(&mut payjoin::bitcoin::key::rand::thread_rng())
+            .cloned()
+            .ok_or_else(|| Error::Generic("No valid OHTTP relays available".to_string()))
+    }
 
-        let selected_relay =
-            match remaining_relays.choose(&mut payjoin::bitcoin::key::rand::thread_rng()) {
-                Some(relay) => relay.clone(),
-                None => {
-                    return Err(Error::Generic(
-                        "Failed to select from remaining relays".to_string(),
-                    ));
-                }
-            };
+    pub(crate) async fn fetch_ohttp_keys(
+        &self,
+        payjoin_directory: impl payjoin::IntoUrl,
+    ) -> Result<payjoin::OhttpKeys, Error> {
+        loop {
+            let selected_relay = self.choose_relay()?;
+            let ohttp_keys =
+                payjoin::io::fetch_ohttp_keys(selected_relay.as_str(), payjoin_directory.as_str())
+                    .await;
 
-        relay_manager
-            .lock()
-            .expect("Lock should not be poisoned")
-            .set_selected_relay(selected_relay.clone());
-
-        let ohttp_keys =
-            payjoin::io::fetch_ohttp_keys(selected_relay.as_str(), payjoin_directory.as_str())
-                .await;
-
-        match ohttp_keys {
-            Ok(keys) => {
-                return Ok(ValidatedOhttpKeys {
-                    ohttp_keys: keys,
-                    relay_url: selected_relay,
-                });
-            }
-            Err(payjoin::io::Error::UnexpectedStatusCode(e)) => {
-                return Err(Error::Generic(format!(
-                    "Unexpected error occurred when fetching OHTTP keys: {}",
-                    e
-                )));
-            }
-            Err(e) => {
-                tracing::debug!(
-                    "Failed to connect to OHTTP relay: {}, {}",
-                    selected_relay,
-                    e
-                );
-                relay_manager
-                    .lock()
-                    .expect("Lock should not be poisoned")
-                    .add_failed_relay(selected_relay);
+            match ohttp_keys {
+                Ok(keys) => return Ok(keys),
+                Err(payjoin::io::Error::UnexpectedStatusCode(e)) => {
+                    return Err(Error::Generic(format!(
+                        "Unexpected error occurred when fetching OHTTP keys: {e}"
+                    )));
+                }
+                Err(e) => {
+                    tracing::debug!(
+                        "Failed to connect to OHTTP relay: {}, {}",
+                        selected_relay,
+                        e
+                    );
+                    self.add_failed_relay(selected_relay);
+                }
             }
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::RelayManager;
+
+    fn relay(url: &str) -> url::Url {
+        url::Url::parse(url).expect("valid relay URL")
+    }
+
+    #[test]
+    fn choose_relay_excludes_failed_relays() {
+        let mut manager = RelayManager::new();
+        let failed = relay("https://failed.example");
+        let available = relay("https://available.example");
+        manager
+            .configure(vec![failed.clone(), available.clone()])
+            .expect("relay configuration");
+
+        manager.add_failed_relay(failed);
+
+        assert_eq!(manager.choose_relay().expect("available relay"), available);
+    }
+
+    #[test]
+    fn choose_relay_fails_when_all_relays_failed() {
+        let mut manager = RelayManager::new();
+        let failed = relay("https://failed.example");
+        manager
+            .configure(vec![failed.clone()])
+            .expect("relay configuration");
+
+        manager.add_failed_relay(failed);
+
+        assert!(manager.choose_relay().is_err());
+    }
+}