]> Untitled Git - bdk/commitdiff
feat(file_store)!: optimize `EntryIter` by reducing syscalls
author志宇 <hello@evanlinjin.me>
Sun, 7 Jan 2024 08:09:03 +0000 (16:09 +0800)
committer志宇 <hello@evanlinjin.me>
Sat, 13 Jan 2024 08:07:43 +0000 (16:07 +0800)
* Wrap file reader with `BufReader`. This reduces calls to `read`.
* Wrap file reader with `CountingReader`. This counts the bytes read by
  the underlying reader. We can rewind without seeking first.

crates/file_store/src/entry_iter.rs

index 770f264f3aa8fad647187a4ed9ecf0c7180f7457..d95a67f8e7f09623a4956b0338995d92a2960858 100644 (file)
@@ -1,12 +1,14 @@
 use bincode::Options;
 use std::{
     fs::File,
-    io::{self, Seek},
+    io::{self, BufReader, Seek},
     marker::PhantomData,
 };
 
 use crate::bincode_options;
 
+type EntryReader<'t> = CountingReader<BufReader<&'t mut File>>;
+
 /// Iterator over entries in a file store.
 ///
 /// Reads and returns an entry each time [`next`] is called. If an error occurs while reading the
@@ -14,7 +16,7 @@ use crate::bincode_options;
 ///
 /// [`next`]: Self::next
 pub struct EntryIter<'t, T> {
-    db_file: Option<&'t mut File>,
+    db_file: Option<EntryReader<'t>>,
 
     /// The file position for the first read of `db_file`.
     start_pos: Option<u64>,
@@ -24,7 +26,7 @@ pub struct EntryIter<'t, T> {
 impl<'t, T> EntryIter<'t, T> {
     pub fn new(start_pos: u64, db_file: &'t mut File) -> Self {
         Self {
-            db_file: Some(db_file),
+            db_file: Some(CountingReader::new(BufReader::new(db_file))),
             start_pos: Some(start_pos),
             types: PhantomData,
         }
@@ -39,32 +41,29 @@ where
 
     fn next(&mut self) -> Option<Self::Item> {
         // closure which reads a single entry starting from `self.pos`
-        let read_one = |f: &mut File, start_pos: Option<u64>| -> Result<Option<T>, IterError> {
-            let pos = match start_pos {
-                Some(pos) => f.seek(io::SeekFrom::Start(pos))?,
-                None => f.stream_position()?,
-            };
-
-            match bincode_options().deserialize_from(&*f) {
-                Ok(changeset) => {
-                    f.stream_position()?;
-                    Ok(Some(changeset))
+        let read_one =
+            |f: &mut EntryReader, start_pos: Option<u64>| -> Result<Option<T>, IterError> {
+                if let Some(pos) = start_pos {
+                    f.seek(io::SeekFrom::Start(pos))?;
                 }
-                Err(e) => {
-                    if let bincode::ErrorKind::Io(inner) = &*e {
-                        if inner.kind() == io::ErrorKind::UnexpectedEof {
-                            let eof = f.seek(io::SeekFrom::End(0))?;
-                            if pos == eof {
+                match bincode_options().deserialize_from(&mut *f) {
+                    Ok(changeset) => {
+                        f.clear_count();
+                        Ok(Some(changeset))
+                    }
+                    Err(e) => {
+                        // allow unexpected EOF if 0 bytes were read
+                        if let bincode::ErrorKind::Io(inner) = &*e {
+                            if inner.kind() == io::ErrorKind::UnexpectedEof && f.count() == 0 {
+                                f.clear_count();
                                 return Ok(None);
                             }
                         }
+                        f.rewind()?;
+                        Err(IterError::Bincode(*e))
                     }
-                    f.seek(io::SeekFrom::Start(pos))?;
-                    Err(IterError::Bincode(*e))
                 }
-            }
-        };
-
+            };
         let result = read_one(self.db_file.as_mut()?, self.start_pos.take());
         if result.is_err() {
             self.db_file = None;
@@ -73,9 +72,13 @@ where
     }
 }
 
-impl From<io::Error> for IterError {
-    fn from(value: io::Error) -> Self {
-        IterError::Io(value)
+impl<'t, T> Drop for EntryIter<'t, T> {
+    fn drop(&mut self) {
+        if let Some(r) = self.db_file.as_mut() {
+            // This syncs the underlying file's offset with the buffer's position. This way, no data
+            // is lost with future reads.
+            let _ = r.stream_position();
+        }
     }
 }
 
@@ -97,4 +100,58 @@ impl core::fmt::Display for IterError {
     }
 }
 
+impl From<io::Error> for IterError {
+    fn from(value: io::Error) -> Self {
+        IterError::Io(value)
+    }
+}
+
 impl std::error::Error for IterError {}
+
+/// A wrapped [`Reader`] which counts total bytes read.
+struct CountingReader<R> {
+    r: R,
+    n: u64,
+}
+
+impl<R> CountingReader<R> {
+    fn new(file: R) -> Self {
+        Self { r: file, n: 0 }
+    }
+
+    /// Counted bytes read.
+    fn count(&self) -> u64 {
+        self.n
+    }
+
+    /// Clear read count.
+    fn clear_count(&mut self) {
+        self.n = 0;
+    }
+}
+
+impl<R: io::Seek> CountingReader<R> {
+    /// Rewind file descriptor offset to before all counted read operations. Then clear the read
+    /// count.
+    fn rewind(&mut self) -> io::Result<u64> {
+        let read = self.r.seek(std::io::SeekFrom::Current(-(self.n as i64)))?;
+        self.n = 0;
+        Ok(read)
+    }
+}
+
+impl<R: io::Read> io::Read for CountingReader<R> {
+    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+        let read = self.r.read(&mut *buf)?;
+        self.n += read as u64;
+        Ok(read)
+    }
+}
+
+impl<R: io::Seek> io::Seek for CountingReader<R> {
+    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
+        let res = self.r.seek(pos);
+        self.n = 0;
+        res
+    }
+}