};
use payjoin::{HpkePublicKey, ImplementationError, UriExt};
use serde_json::{json, to_string_pretty};
-use std::{
- path::PathBuf,
- sync::{Arc, Mutex},
-};
+use std::{path::PathBuf, sync::Arc};
use crate::payjoin::db::{ReceiverPersister, SenderPersister, open_payjoin_db};
-use crate::payjoin::ohttp::{RelayManager, fetch_ohttp_keys};
+use crate::payjoin::ohttp::RelayManager;
pub mod db;
pub mod ohttp;
/// [`PayjoinManager::proceed_receiver_session`].
pub(crate) struct PayjoinManager<'a> {
wallet: &'a mut Wallet,
- relay_manager: Arc<Mutex<RelayManager>>,
+ relay_manager: RelayManager,
db: Arc<crate::payjoin::db::Database>,
}
wallet_name: &str,
) -> Result<Self, Error> {
let db = open_payjoin_db(datadir, wallet_name)?;
- let relay_manager = Arc::new(Mutex::new(RelayManager::new()));
+ let relay_manager = RelayManager::new();
Ok(Self {
wallet,
.collect::<Result<_, _>>()
.map_err(|e| Error::Generic(format!("Failed to parse one or more OHTTP URLs: {e}")))?;
- if ohttp_relays.is_empty() {
- return Err(Error::Generic(
- "At least one valid OHTTP relay must be provided.".into(),
- ));
- }
-
- let ohttp_keys =
- fetch_ohttp_keys(ohttp_relays, &directory, self.relay_manager.clone()).await?;
+ self.relay_manager.configure(ohttp_relays)?;
+ let ohttp_keys = self.relay_manager.fetch_ohttp_keys(&directory).await?;
let persister = crate::payjoin::db::ReceiverPersister::new(self.db.clone())?;
.map(FeeRate::from_sat_per_kwu)
.unwrap_or(FeeRate::BROADCAST_MIN);
- let receiver = payjoin::receive::v2::ReceiverBuilder::new(
- address.address,
- directory,
- ohttp_keys.ohttp_keys,
- )?
- .with_amount(payjoin::bitcoin::Amount::from_sat(amount))
- .with_max_fee_rate(checked_max_fee_rate)
- .build()
- .save(&persister)
- .map_err(|e| {
- Error::Generic(format!(
- "Failed to persister the receiver after initialization: {e}"
- ))
- })?;
+ let receiver =
+ payjoin::receive::v2::ReceiverBuilder::new(address.address, directory, ohttp_keys)?
+ .with_amount(payjoin::bitcoin::Amount::from_sat(amount))
+ .with_max_fee_rate(checked_max_fee_rate)
+ .build()
+ .save(&persister)
+ .map_err(|e| {
+ Error::Generic(format!(
+ "Failed to persister the receiver after initialization: {e}"
+ ))
+ })?;
let pj_uri = receiver.pj_uri();
println!("Request Payjoin by sharing this Payjoin Uri:");
self.proceed_receiver_session(
ReceiveSession::Initialized(receiver.clone()),
&persister,
- ohttp_keys.relay_url.to_string(),
checked_max_fee_rate,
blockchain_client,
)
Error::Generic(format!("Failed to parse one or more OHTTP URLs: {e}"))
})?;
- if ohttp_relays.is_empty() {
- return Err(Error::Generic(
- "At least one valid OHTTP relay must be provided.".into(),
- ));
- }
+ self.relay_manager.configure(ohttp_relays)?;
// Check for existing session with the same receiver pubkey
let receiver_pubkey = v2_param.receiver_pubkey();
let existing_session =
(SendSession::WithReplyKey(sender), persister)
};
- self.proceed_sender_session(
- sender_state,
- &persister,
- ohttp_relays,
- blockchain_client,
- )
- .await?
+ self.proceed_sender_session(sender_state, &persister, blockchain_client)
+ .await?
}
_ => {
unimplemented!("Payjoin version not recognized.");
&mut self,
session: ReceiveSession,
persister: &impl SessionPersister<SessionEvent = ReceiverSessionEvent>,
- relay: impl payjoin::IntoUrl,
max_fee_rate: FeeRate,
blockchain_client: &BlockchainClient,
) -> Result<(), Error> {
match session {
ReceiveSession::Initialized(proposal) => {
- self.read_from_directory(
- proposal,
- persister,
- relay,
- max_fee_rate,
- blockchain_client,
- )
- .await
+ self.read_from_directory(proposal, persister, max_fee_rate, blockchain_client)
+ .await
}
ReceiveSession::UncheckedOriginalPayload(proposal) => {
self.check_proposal(proposal, persister, max_fee_rate, blockchain_client)
&mut self,
receiver: Receiver<Initialized>,
persister: &impl SessionPersister<SessionEvent = ReceiverSessionEvent>,
- relay: impl payjoin::IntoUrl,
max_fee_rate: FeeRate,
blockchain_client: &BlockchainClient,
) -> Result<(), Error> {
let mut current_receiver_typestate = receiver;
let next_receiver_typestate = loop {
- let (req, context) = current_receiver_typestate.create_poll_request(relay.as_str())?;
- let response = self.send_payjoin_post_request(req).await?;
+ let (response, context) = self
+ .post_via_relay(|relay| current_receiver_typestate.create_poll_request(relay))
+ .await?;
let state_transition = current_receiver_typestate
.process_response(response.bytes().await?.to_vec().as_slice(), context)
.save(persister);
persister: &impl SessionPersister<SessionEvent = ReceiverSessionEvent>,
blockchain_client: &BlockchainClient,
) -> Result<(), Error> {
- let (req, ctx) = receiver
- .create_post_request(
- self.unwrap_relay_or_else_fetch(vec![], None::<&str>)
- .await?
- .as_str(),
- )
- .map_err(|e| {
- Error::Generic(format!("Error occurred when creating a post request for sending final Payjoin proposal: {e}"))
- })?;
-
- let res = self.send_payjoin_post_request(req).await?;
+ let (res, ctx) = self
+ .post_via_relay(|relay| receiver.create_post_request(relay))
+ .await?;
let payjoin_psbt = receiver.psbt().clone();
let next_receiver_typestate = receiver
.process_response(&res.bytes().await?, ctx)
receiver: Receiver<HasReplyableError>,
persister: &impl SessionPersister<SessionEvent = ReceiverSessionEvent>,
) -> Result<(), Error> {
- let (err_req, err_ctx) = receiver
- .create_error_request(
- self.unwrap_relay_or_else_fetch(vec![], None::<&str>)
- .await?
- .as_str(),
- )
- .map_err(|e| {
- Error::Generic(format!(
- "Error occurred when creating a receiver error request: {}",
- e
- ))
- })?;
-
- let err_response = match self.send_payjoin_post_request(err_req).await {
- Ok(response) => response,
- Err(e) => {
- return Err(Error::Generic(format!(
- "Failed to post error request: {}",
- e
- )));
- }
- };
+ let (err_response, err_ctx) = self
+ .post_via_relay(|relay| receiver.create_error_request(relay))
+ .await?;
let err_bytes = match err_response.bytes().await {
Ok(bytes) => bytes,
&self,
session: SendSession,
persister: &impl SessionPersister<SessionEvent = SenderSessionEvent>,
- ohttp_relays: Vec<url::Url>,
blockchain_client: &BlockchainClient,
) -> Result<Txid, Error> {
match session {
SendSession::WithReplyKey(context) => {
- let relay = self
- .unwrap_relay_or_else_fetch(ohttp_relays, Some(context.endpoint()))
- .await?;
- self.post_original_proposal(
- context,
- persister,
- blockchain_client,
- relay.to_string(),
- )
- .await
+ self.post_original_proposal(context, persister, blockchain_client)
+ .await
}
SendSession::PollingForProposal(context) => {
- let relay = self
- .unwrap_relay_or_else_fetch(ohttp_relays, Some(context.endpoint()))
- .await?;
- self.get_proposed_payjoin_proposal(
- context,
- persister,
- blockchain_client,
- relay.to_string(),
- )
- .await
+ self.get_proposed_payjoin_proposal(context, persister, blockchain_client)
+ .await
}
SendSession::Closed(SenderSessionOutcome::Success(psbt)) => {
self.process_payjoin_proposal(psbt, blockchain_client).await
}
}
- async fn unwrap_relay_or_else_fetch(
- &self,
- ohttp_relays: Vec<url::Url>,
- directory: Option<impl payjoin::IntoUrl>,
- ) -> Result<url::Url, Error> {
- let selected_relay = self
- .relay_manager
- .lock()
- .expect("Lock should not be poisoned")
- .get_selected_relay();
- match selected_relay {
- Some(relay) => Ok(relay),
- None => {
- let directory = directory.ok_or_else(|| {
- Error::Generic("No directory URL provided and no relay selected".to_string())
- })?;
- Ok(
- fetch_ohttp_keys(ohttp_relays, directory, self.relay_manager.clone())
- .await?
- .relay_url,
- )
- }
- }
- }
-
async fn post_original_proposal(
&self,
sender: Sender<WithReplyKey>,
persister: &impl SessionPersister<SessionEvent = SenderSessionEvent>,
blockchain_client: &BlockchainClient,
- relay: impl payjoin::IntoUrl,
) -> Result<Txid, Error> {
- let (req, ctx) = sender.create_v2_post_request(relay.as_str())?;
- let response = self.send_payjoin_post_request(req).await?;
+ let (response, ctx) = self
+ .post_via_relay(|relay| sender.create_v2_post_request(relay))
+ .await?;
let sender = sender
.process_response(&response.bytes().await?, ctx)
.save(persister)?;
- self.get_proposed_payjoin_proposal(sender, persister, blockchain_client, relay)
+ self.get_proposed_payjoin_proposal(sender, persister, blockchain_client)
.await
}
sender: Sender<PollingForProposal>,
persister: &impl SessionPersister<SessionEvent = SenderSessionEvent>,
blockchain_client: &BlockchainClient,
- relay: impl payjoin::IntoUrl,
) -> Result<Txid, Error> {
let mut sender = sender.clone();
loop {
- let (req, ctx) = sender.create_poll_request(relay.as_str())?;
- let response = self.send_payjoin_post_request(req).await?;
+ let (response, ctx) = self
+ .post_via_relay(|relay| sender.create_poll_request(relay))
+ .await?;
let processed_response = sender
.process_response(&response.bytes().await?, ctx)
.save(persister);
.header("Content-Type", req.content_type)
.body(req.body)
.send()
- .await
+ .await?
+ .error_for_status()
+ }
+
+ async fn post_via_relay<F, T, E>(&self, mut build: F) -> Result<(reqwest::Response, T), Error>
+ where
+ F: FnMut(&str) -> Result<(payjoin::Request, T), E>,
+ E: std::fmt::Display,
+ {
+ loop {
+ let relay = self.relay_manager.choose_relay()?;
+ // Build a fresh request for each attempt. Reusing an OHTTP
+ // ciphertext would let relays correlate retransmissions.
+ let (req, context) = build(relay.as_str())
+ .map_err(|e| Error::Generic(format!("Failed to create OHTTP request: {e}")))?;
+
+ match self.send_payjoin_post_request(req).await {
+ Ok(response) => return Ok((response, context)),
+ Err(e) => {
+ tracing::debug!("Request to OHTTP relay {relay} failed: {e:?}");
+ self.relay_manager.add_failed_relay(relay);
+ }
+ }
+ }
}
/// Resume pending payjoin sessions from the database
pub async fn resume_payjoins(
&mut self,
- directory: String,
+ _directory: String,
ohttp_relays: Vec<String>,
session_id: Option<i64>,
blockchain_client: &BlockchainClient,
.map(|s| url::Url::parse(&s))
.collect::<Result<_, _>>()
.map_err(|e| Error::Generic(format!("Failed to parse OHTTP URLs: {e}")))?;
-
- let relay = self
- .unwrap_relay_or_else_fetch(ohttp_relays, Some(&directory))
- .await?;
+ self.relay_manager.configure(ohttp_relays)?;
let max_fee_rate = FeeRate::BROADCAST_MIN;
let total_sessions = recv_session_ids.len() + send_session_ids.len();
self.proceed_receiver_session(
receiver_state,
&persister,
- relay.as_str(),
max_fee_rate,
blockchain_client,
),
println!("Resuming sender session {}", session_id);
match tokio::time::timeout(
std::time::Duration::from_secs(30),
- self.proceed_sender_session(
- sender_state,
- &persister,
- vec![relay.clone()],
- blockchain_client,
- ),
+ self.proceed_sender_session(sender_state, &persister, blockchain_client),
)
.await
{
#[derive(Debug, Clone)]
pub(crate) struct RelayManager {
- selected_relay: Option<url::Url>,
- failed_relays: Vec<url::Url>,
+ relays: Vec<url::Url>,
+ failed_relays: Arc<Mutex<Vec<url::Url>>>,
}
impl RelayManager {
- pub fn new() -> Self {
- RelayManager {
- selected_relay: None,
- failed_relays: Vec::new(),
+ pub(crate) fn new() -> Self {
+ Self {
+ relays: Vec::new(),
+ failed_relays: Arc::new(Mutex::new(Vec::new())),
}
}
- pub fn set_selected_relay(&mut self, relay: url::Url) {
- self.selected_relay = Some(relay);
- }
+ pub(crate) fn configure(&mut self, relays: Vec<url::Url>) -> Result<(), Error> {
+ if relays.is_empty() {
+ return Err(Error::Generic(
+ "At least one valid OHTTP relay must be provided.".into(),
+ ));
+ }
- pub fn get_selected_relay(&self) -> Option<url::Url> {
- self.selected_relay.clone()
+ self.relays = relays;
+ self.failed_relays
+ .lock()
+ .expect("Lock should not be poisoned")
+ .clear();
+ Ok(())
}
- pub fn add_failed_relay(&mut self, relay: url::Url) {
- self.failed_relays.push(relay);
+ pub(crate) fn add_failed_relay(&self, relay: url::Url) {
+ let mut failed_relays = self
+ .failed_relays
+ .lock()
+ .expect("Lock should not be poisoned");
+ if !failed_relays.contains(&relay) {
+ failed_relays.push(relay);
+ }
}
- pub fn get_failed_relays(&self) -> Vec<url::Url> {
- self.failed_relays.clone()
- }
-}
+ pub(crate) fn choose_relay(&self) -> Result<url::Url, Error> {
+ use payjoin::bitcoin::secp256k1::rand::prelude::SliceRandom;
-pub(crate) struct ValidatedOhttpKeys {
- pub(crate) ohttp_keys: payjoin::OhttpKeys,
- pub(crate) relay_url: url::Url,
-}
-
-pub(crate) async fn fetch_ohttp_keys(
- relays: Vec<url::Url>,
- payjoin_directory: impl payjoin::IntoUrl,
- relay_manager: Arc<Mutex<RelayManager>>,
-) -> Result<ValidatedOhttpKeys, Error> {
- use payjoin::bitcoin::secp256k1::rand::prelude::SliceRandom;
-
- loop {
- let failed_relays = relay_manager
+ let failed_relays = self
+ .failed_relays
.lock()
- .expect("Lock should not be poisoned")
- .get_failed_relays();
-
- let remaining_relays: Vec<_> = relays
+ .expect("Lock should not be poisoned");
+ let remaining_relays: Vec<_> = self
+ .relays
.iter()
- .filter(|r| !failed_relays.contains(r))
+ .filter(|relay| !failed_relays.contains(relay))
.cloned()
.collect();
- if remaining_relays.is_empty() {
- return Err(Error::Generic(
- "No valid OHTTP relays available".to_string(),
- ));
- }
+ remaining_relays
+ .choose(&mut payjoin::bitcoin::key::rand::thread_rng())
+ .cloned()
+ .ok_or_else(|| Error::Generic("No valid OHTTP relays available".to_string()))
+ }
- let selected_relay =
- match remaining_relays.choose(&mut payjoin::bitcoin::key::rand::thread_rng()) {
- Some(relay) => relay.clone(),
- None => {
- return Err(Error::Generic(
- "Failed to select from remaining relays".to_string(),
- ));
- }
- };
+ pub(crate) async fn fetch_ohttp_keys(
+ &self,
+ payjoin_directory: impl payjoin::IntoUrl,
+ ) -> Result<payjoin::OhttpKeys, Error> {
+ loop {
+ let selected_relay = self.choose_relay()?;
+ let ohttp_keys =
+ payjoin::io::fetch_ohttp_keys(selected_relay.as_str(), payjoin_directory.as_str())
+ .await;
- relay_manager
- .lock()
- .expect("Lock should not be poisoned")
- .set_selected_relay(selected_relay.clone());
-
- let ohttp_keys =
- payjoin::io::fetch_ohttp_keys(selected_relay.as_str(), payjoin_directory.as_str())
- .await;
-
- match ohttp_keys {
- Ok(keys) => {
- return Ok(ValidatedOhttpKeys {
- ohttp_keys: keys,
- relay_url: selected_relay,
- });
- }
- Err(payjoin::io::Error::UnexpectedStatusCode(e)) => {
- return Err(Error::Generic(format!(
- "Unexpected error occurred when fetching OHTTP keys: {}",
- e
- )));
- }
- Err(e) => {
- tracing::debug!(
- "Failed to connect to OHTTP relay: {}, {}",
- selected_relay,
- e
- );
- relay_manager
- .lock()
- .expect("Lock should not be poisoned")
- .add_failed_relay(selected_relay);
+ match ohttp_keys {
+ Ok(keys) => return Ok(keys),
+ Err(payjoin::io::Error::UnexpectedStatusCode(e)) => {
+ return Err(Error::Generic(format!(
+ "Unexpected error occurred when fetching OHTTP keys: {e}"
+ )));
+ }
+ Err(e) => {
+ tracing::debug!(
+ "Failed to connect to OHTTP relay: {}, {}",
+ selected_relay,
+ e
+ );
+ self.add_failed_relay(selected_relay);
+ }
}
}
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::RelayManager;
+
+ fn relay(url: &str) -> url::Url {
+ url::Url::parse(url).expect("valid relay URL")
+ }
+
+ #[test]
+ fn choose_relay_excludes_failed_relays() {
+ let mut manager = RelayManager::new();
+ let failed = relay("https://failed.example");
+ let available = relay("https://available.example");
+ manager
+ .configure(vec![failed.clone(), available.clone()])
+ .expect("relay configuration");
+
+ manager.add_failed_relay(failed);
+
+ assert_eq!(manager.choose_relay().expect("available relay"), available);
+ }
+
+ #[test]
+ fn choose_relay_fails_when_all_relays_failed() {
+ let mut manager = RelayManager::new();
+ let failed = relay("https://failed.example");
+ manager
+ .configure(vec![failed.clone()])
+ .expect("relay configuration");
+
+ manager.add_failed_relay(failed);
+
+ assert!(manager.choose_relay().is_err());
+ }
+}