Source code for seed_vault.ui.components.waveform

from typing import List, Dict, Union
from obspy import Stream, Trace
from obspy import UTCDateTime
import threading
from seed_vault.enums.config import WorkflowType
from seed_vault.models.config import SeismoLoaderSettings
from seed_vault.service.seismoloader import run_event
from obspy.clients.fdsn import Client
from obspy.taup import TauPyModel
from seed_vault.ui.components.display_log import ConsoleDisplay
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from html import escape
from seed_vault.ui.components.continuous_waveform import ContinuousComponents
from seed_vault.service.utils import check_client_services
from copy import deepcopy
from seed_vault.ui.app_pages.helpers.common import save_filter
import time
import sys
import queue


query_thread = None
stop_event = threading.Event()
log_queue = queue.Queue()


if "query_done" not in st.session_state:
    st.session_state["query_done"] = False
if "trigger_rerun" not in st.session_state:
    st.session_state["trigger_rerun"] = False
if "log_entries" not in st.session_state:
    st.session_state["log_entries"] = []

[docs] def get_tele_filter(tr): """Calculate appropriate filter band for teleseismic data based on distance. This function determines the optimal frequency band for filtering teleseismic data based on the distance from the source and the sensor type. Args: tr (Trace): ObsPy Trace object containing waveform data. Returns: tuple: A tuple of (f0, f1) where: - f0 (float): Lower frequency bound in Hz - f1 (float): Upper frequency bound in Hz Returns (0, 0) for non-seismic sensors. Note: The filter bands are optimized for different distance ranges: - < 50 km: 2.0-15 Hz - 50-100 km: 1.8-12 Hz - 100-250 km: 1.7-10 Hz - 250-500 km: 1.6-8 Hz - 500-1000 km: 1.5-6 Hz - 1000-2500 km: 1.4-5 Hz - 2500-5000 km: 1.2-4 Hz - 5000-10000 km: 1.0-3 Hz - > 10000 km: 0.7-2 Hz """ distance_km = tr.stats.distance_km nyq = tr.stats.sampling_rate/2 - 0.1 senstype = tr.stats.channel[1] if senstype not in ['H','N']: return 0,0 # flagged elsewhere if distance_km < 50: f0,f1 = 2.0,15 elif distance_km < 100: f0,f1 = 1.8,12 elif distance_km < 250: f0,f1 = 1.7,10 elif distance_km < 500: f0,f1 = 1.6,8 elif distance_km < 1000: f0,f1 = 1.5,6 elif distance_km < 2500: f0,f1 = 1.4,5 elif distance_km < 5000: f0,f1 = 1.2,4 elif distance_km < 10000: f0,f1 = 1.0,3 else: f0,f1 = 0.7,2 return min(f0,nyq),min(f1,nyq)
[docs] class WaveformFilterMenu: settings: SeismoLoaderSettings network_filter: str station_filter: str channel_filter: str available_channels: List[str] display_limit: int def __init__(self, settings: SeismoLoaderSettings): """Initialize the WaveformFilterMenu. Args: settings (SeismoLoaderSettings): Configuration settings for seismic data processing. """ self.settings = settings self.old_settings = deepcopy(settings) # Track previous state self.network_filter = "All networks" self.station_filter = "All stations" self.channel_filter = "All channels" self.available_channels = ["All channels"] self.display_limit = 50 # Track previous filter state self.old_filter_state = { 'network_filter': self.network_filter, 'station_filter': self.station_filter, 'channel_filter': self.channel_filter, 'display_limit': self.display_limit }
[docs] def refresh_filters(self): """Check for changes in filter settings and trigger UI updates. This method compares current filter settings with previous state and triggers a UI refresh if changes are detected. It also handles saving of filter settings. Note: The method uses Streamlit's rerun mechanism to update the UI when changes are detected. """ current_state = { 'network_filter': self.network_filter, 'station_filter': self.station_filter, 'channel_filter': self.channel_filter, 'display_limit': self.display_limit } # Check if filter state changed if current_state != self.old_filter_state: self.old_filter_state = current_state.copy() st.rerun() # Check if settings changed changes = self.settings.has_changed(self.old_settings) if changes.get('has_changed', False): self.old_settings = deepcopy(self.settings) save_filter(self.settings) st.rerun()
[docs] def update_available_channels(self, stream: Stream): """Update the list of available channels based on the current stream. This method extracts unique channel codes from the provided stream and updates the available_channels list. Args: stream (Stream): ObsPy Stream object containing waveform data. Note: The method handles different types of stream objects and ensures "All channels" remains as the first option in the list. """ if not stream: self.available_channels = ["All channels"] return channels = set() # Handle different types of stream objects if isinstance(stream, Stream): # Case 1: ObsPy Stream object for tr in stream: if hasattr(tr.stats, 'channel'): channels.add(tr.stats.channel) elif isinstance(stream, list): # Case 2: List of traces or Stream objects for item in stream: if isinstance(item, Trace): # Individual trace if hasattr(item.stats, 'channel'): channels.add(item.stats.channel) elif isinstance(item, Stream): # Stream object in a list for tr in item: if hasattr(tr.stats, 'channel'): channels.add(tr.stats.channel) else: # Try to handle as a generic object with stats.channel try: if hasattr(item, 'stats') and hasattr(item.stats, 'channel'): channels.add(item.stats.channel) except: pass # If we found any channels, update the available_channels list # Ensure "All channels" is always at the top by separating it from the sorted list if channels: sorted_channels = sorted(list(channels)) self.available_channels = ["All channels"] + sorted_channels else: self.available_channels = ["All channels"] # Reset channel filter if current selection is invalid if self.channel_filter not in self.available_channels: self.channel_filter = "All channels"
[docs] def render(self, stream=None): """Render the waveform filter menu interface. This method creates the UI for waveform filtering and control, including: - Network, station, and channel filters - Display limit controls - Status information - Reset functionality Args: stream (Stream, optional): Current waveform stream to filter. If None, only basic controls are shown. Note: The interface is organized in expandable sections for better user experience and space management. """ st.sidebar.title("Waveform Controls") # Step 1: Data Retrieval Settings with st.sidebar.expander("Step 1: Data Source", expanded=True): st.subheader("🔍 Filter Events Around Individual Stations") cc1, cc2 = st.columns([1, 1]) with cc1: min_radius = st.number_input("Minimum radius (degree)", value=self.settings.event.min_radius or 0.0, step=0.5, min_value=0.0, max_value=180.0) if min_radius != self.settings.event.min_radius: self.settings.event.min_radius = min_radius self.refresh_filters() with cc2: max_radius = st.number_input("Maximum radius (degree)", value=self.settings.event.max_radius or 90.0, step=0.5, min_value=0.0, max_value=180.0) if max_radius != self.settings.event.max_radius: self.settings.event.max_radius = max_radius self.refresh_filters() st.subheader("🔍 Time Window") # Update time window settings with immediate refresh before_p = st.number_input( "Start (secs before P arrival):", value=self.settings.event.before_p_sec or 20, step = 5, help="Time window before P arrival", key="before_p_input" ) if before_p != self.settings.event.before_p_sec: self.settings.event.before_p_sec = before_p self.refresh_filters() after_p = st.number_input( "End (secs after P arrival):", value=self.settings.event.after_p_sec or 100, step = 5, help="Time window after P arrival", key="after_p_input" ) if after_p != self.settings.event.after_p_sec: self.settings.event.after_p_sec = after_p self.refresh_filters() # Client selection with immediate refresh client_options = list(self.settings.client_url_mapping.get_clients()) selected_client = st.selectbox( 'Choose a client:', client_options, index=client_options.index(self.settings.waveform.client), key="waveform_client_select" ) if selected_client != self.settings.waveform.client: self.settings.waveform.client = selected_client self.refresh_filters() # Check services for selected client services = check_client_services(self.settings.waveform.client) if not services['dataselect']: st.warning(f"⚠️ Warning: Selected client '{self.settings.waveform.client}' does not support WAVEFORM service. Please choose another client.") # Add Download Preferences section st.subheader("📊 Download Preferences") # Channel Priority Input channel_pref = st.text_input( "Channel Priority", value=self.settings.waveform.channel_pref, help="Order of preferred channels (e.g., HH,BH,EH). Only the first existing channel in this list will be downloaded.", key="channel_pref" ) # Validate and update channel preferences if channel_pref: # Remove spaces and convert to uppercase channel_pref = channel_pref.replace(" ", "").upper() # Basic validation channel_codes = channel_pref.split(",") is_valid = all(len(code) == 2 for code in channel_codes) if is_valid: self.settings.waveform.channel_pref = channel_pref else: st.error("Invalid channel format. Each channel code should be 2 characters (e.g., HH,BH,EH)") # Location Priority Input location_pref = st.text_input( "Location Priority", value=self.settings.waveform.location_pref, help="Order of preferred location codes (e.g., 00,--,10,20). Only the first existing location code in this list will be downloaded.. Use -- or '' for blank location.", key="location_pref" ) # Validate and update location preferences if location_pref: # Remove spaces location_pref = location_pref.replace(" ", "") # Basic validation location_codes = location_pref.split(",") is_valid = all(len(code) <= 2 for code in location_codes) if is_valid: self.settings.waveform.location_pref = location_pref else: st.error("Invalid location format. Each location code should be 0-2 characters (e.g., 00,--,10,20)") if stream is not None: # Get network codes and sort them, ensuring "All networks" is at the top network_codes = list(set([inv.code for inv in self.settings.station.selected_invs])) network_codes.sort() # Sort alphabetically networks = ["All networks"] + network_codes # Ensure "All networks" is at the top selected_network = st.selectbox( "Network:", networks, index=networks.index(self.network_filter), help="Filter by network", key="network_filter_select" ) if selected_network != self.network_filter: self.network_filter = selected_network self.refresh_filters() # Station filter with immediate refresh # Get station codes and sort them, ensuring "All stations" is at the top station_codes = [] for inv in self.settings.station.selected_invs: station_codes.extend([sta.code for sta in inv]) station_codes = list(dict.fromkeys(station_codes)) # Remove duplicates station_codes.sort() # Sort alphabetically stations = ["All stations"] + station_codes # Ensure "All stations" is at the top selected_station = st.selectbox( "Station:", stations, index=stations.index(self.station_filter), help="Filter by station", key="station_filter_select" ) if selected_station != self.station_filter: self.station_filter = selected_station self.refresh_filters() # Channel filter with immediate refresh self.channel_filter = st.selectbox( "Channel:", options=self.available_channels, index=self.available_channels.index(self.channel_filter), help="Filter by channel", key="channel_filter_select" ) if self.channel_filter != self.old_filter_state['channel_filter']: self.old_filter_state['channel_filter'] = self.channel_filter self.refresh_filters() st.subheader("📊 Display Options") display_limit = st.selectbox( "Waveforms per page:", options=[10, 25, 50], index=[10, 25, 50].index(self.display_limit), key="waveform_display_limit", help="Number of waveforms to show per page" ) if display_limit != self.display_limit: self.display_limit = display_limit self.refresh_filters() # Add status information if stream: st.sidebar.info(f"Total waveforms: {len(stream)}") # Add reset filters button if st.sidebar.button("Reset Filters"): self.network_filter = "All networks" self.station_filter = "All stations" self.channel_filter = "All channels" self.display_limit = 50 self.refresh_filters()
[docs] class WaveformDisplay: """A component for displaying and managing waveform data visualization. This class handles the display of seismic waveform data, including both event-based and station-based views, with support for filtering and pagination. Attributes: settings (SeismoLoaderSettings): Configuration settings for seismic data processing. filter_menu (WaveformFilterMenu): Menu component for filtering waveforms. client (Client): FDSN client for waveform data retrieval. ttmodel (TauPyModel): Travel-time model for seismic phases. stream (List[Stream]): List of waveform streams. missing_data (Dict): Dictionary tracking missing data. console (ConsoleDisplay): Console for logging output. """ def __init__(self, settings: SeismoLoaderSettings, filter_menu: WaveformFilterMenu): """Initialize the WaveformDisplay component. Args: settings (SeismoLoaderSettings): Configuration settings for seismic data processing. filter_menu (WaveformFilterMenu): Menu component for filtering waveforms. """ self.settings = settings self.filter_menu = filter_menu try: self.client = Client(self.settings.waveform.client) except ValueError as e: st.error(f"Error: {str(e)} Waveform client is set to {self.settings.waveform.client}, which seems does not exists. Please navigate to the settings page and use the Clients tab to add the client or fix the stored config.cfg file.") self.ttmodel = TauPyModel("iasp91") self.stream = [] self.missing_data = {} self.console = ConsoleDisplay() # Add console display
[docs] def apply_filters(self, stream) -> Stream: """Apply filters to the waveform stream based on user selections. Args: stream (Stream): Input ObsPy Stream object to filter. Returns: Stream: Filtered stream containing only traces matching the selected network, station, and channel filters. """ filtered_stream = Stream() # Handle case where stream is a list of traces if isinstance(stream, list): stream = Stream(traces=stream) if not stream: return filtered_stream for tr in stream: try: if (self.filter_menu.network_filter == "All networks" or tr.stats.network == self.filter_menu.network_filter) and \ (self.filter_menu.station_filter == "All stations" or tr.stats.station == self.filter_menu.station_filter) and \ (self.filter_menu.channel_filter == "All channels" or tr.stats.channel == self.filter_menu.channel_filter): filtered_stream += tr except AttributeError as e: continue return filtered_stream
[docs] def fetch_data(self): """Fetch waveform data in a background thread with logging. This method sets up a custom logging system, retrieves waveform data, and handles any errors or cancellations during the process. Note: The method updates the session state with processing status and logs. """ # Custom stdout/stderr handler that writes to both the original traces and our queue class QueueLogger: def __init__(self, original_stream, queue): self.original_stream = original_stream self.queue = queue self.buffer = "" def write(self, text): self.original_stream.write(text) self.buffer += text # Only flush when buffer gets large enough if len(self.buffer) > 80: # Buffer getting too large, flush it self.queue.put(self.buffer) self.buffer = "" """ def write(self, text): self.original_stream.write(text) self.buffer += text if '\n' in text: ### this may be bad lines = self.buffer.split('\n') for line in lines[:-1]: # All complete lines if line: # Skip empty lines self.queue.put(line) self.buffer = lines[-1] # Keep any partial line # Also handle case where no newline but we have content elif text and len(self.buffer) > 80: # Buffer getting long, flush it self.queue.put(self.buffer) self.buffer = "" """ def flush(self): self.original_stream.flush() if self.buffer: # Flush any remaining content in buffer self.queue.put(self.buffer) self.buffer = "" # Set up queue loggers original_stdout = sys.stdout original_stderr = sys.stderr sys.stdout = QueueLogger(original_stdout, log_queue) sys.stderr = QueueLogger(original_stderr, log_queue) try: print("Starting waveform download process...") stream_and_missing = run_event(self.settings, stop_event) if stream_and_missing: self.stream, self.missing_data = stream_and_missing success = True print("Download completed successfully.") # If stopped via the cancel button, reset it continue to plotting as normal stop_event.clear() else: success = False if stop_event.is_set(): print("Download was cancelled by user.") st.session_state["download_cancelled"] = True else: print("Download failed.") except Exception as e: success = False print(f"Error: {str(e)}") finally: # Flush any remaining content sys.stdout.flush() sys.stderr.flush() # Ensure stop event is cleared stop_event.clear() # Restore original stdout/stderr sys.stdout = original_stdout sys.stderr = original_stderr st.session_state.update({ "query_done": True, "is_downloading": False, "trigger_rerun": True })
[docs] def retrieve_waveforms(self): """Initiate waveform retrieval in a background thread. This method starts a new thread for waveform data retrieval and updates the UI state accordingly. Note: The method handles thread creation, state management, and UI updates. """ if not self.settings.event.selected_catalogs or not self.settings.station.selected_invs: st.warning("Please select events and stations before downloading waveforms.") return stop_event.clear() # Reset cancellation flag st.session_state["query_thread"] = threading.Thread(target=self.fetch_data, daemon=True) st.session_state["query_thread"].start() st.session_state.update({ "is_downloading": True, "query_done": False, "polling_active": True, "download_cancelled": False # Initialize cancellation flag }) st.rerun()
def _get_trace_color(self, tr) -> str: """Get color for a trace based on its channel component. Args: tr (Trace): ObsPy Trace object. Returns: str: Color code for the trace based on its component: - 'Z': black - 'N' or '1': blue - 'E' or '2': green - others: gray - non-seismic sensors: tomato """ # Extract last character of channel code component = tr.stats.channel[-1].upper() sensortype = tr.stats.channel[1].upper() if sensortype not in ['H','N']: return 'tomato' # Standard color scheme for components if component == 'Z': return 'black' elif component in ['N', '1']: return 'blue' elif component in ['E', '2']: return 'green' else: return 'gray' def _calculate_figure_dimensions(self, num_traces: int) -> tuple: """Calculate figure dimensions based on number of traces. Args: num_traces (int): Number of traces to display. Returns: tuple: A tuple of (width, height) in inches for the figure. Width is fixed at 12 inches, height is calculated based on number of traces with a minimum of 4 inches. """ width = 12 # Slightly wider for better readability height_per_trace = 1.0 # Reduced slightly to fit more traces # Remove maximum height limit, keep minimum total_height = num_traces * height_per_trace + 0.5 total_height = max(4, total_height) # Only keep minimum height limit return (width, total_height)
[docs] def plot_event_view(self, event, stream: Stream, page: int, num_pages: int): """Plot event view with proper time alignment and improved layout. Args: event: Event object containing event information. stream (Stream): ObsPy Stream object containing waveform data. page (int): Current page number for pagination. num_pages (int): Total number of pages. Returns: Figure: Matplotlib figure object containing the plot. """ if not stream: return # Sort traces by distance (via starttime) stream.traces.sort(key=lambda x: x.stats.starttime) # Get current page's traces start_idx = page * self.filter_menu.display_limit end_idx = start_idx + self.filter_menu.display_limit current_stream = Stream(traces=stream.traces[start_idx:end_idx]) # Create figure with standardized dimensions num_traces = len(current_stream) width, height = self._calculate_figure_dimensions(num_traces) fig = plt.figure(figsize=(width, height)) # Use GridSpec with standardized spacing gs = plt.GridSpec(num_traces, 1, height_ratios=[1] * num_traces, hspace=0.05, top=0.99, # Adjusted from 0.97 to remove title space bottom=0.08, left=0.1, right=0.9) axes = [plt.subplot(gs[i]) for i in range(num_traces)] # Process each trace for i, tr in enumerate(current_stream): #print("DEBUG plotted trace:",tr) ax = axes[i] # Calculate and add an appropriate filter for plotting filter_min,filter_max = get_tele_filter(tr) if filter_min < filter_max: tr.stats.filterband = (filter_min,filter_max) if hasattr(tr.stats, 'p_arrival'): p_time = UTCDateTime(tr.stats.p_arrival) before_p = self.settings.event.before_p_sec after_p = self.settings.event.after_p_sec # Trim trace to window around P window_start = p_time - before_p window_end = p_time + after_p tr_windowed = tr.slice(window_start, window_end) # Pre-process and apply a bandpass filter if tr_windowed.stats.sampling_rate/2 > filter_min and filter_min<filter_max: tr_windowed.detrend() tr_windowed.taper(.005) tr_windowed.filter('bandpass',freqmin=filter_min,freqmax=filter_max,zerophase=True) # Calculate times relative to P arrival times = np.arange(tr_windowed.stats.npts) * tr_windowed.stats.delta relative_times = times - before_p # This makes P arrival at t=0 # Plot the trace ax.plot(relative_times, tr_windowed.data, '-', color=self._get_trace_color(tr), linewidth=0.8) # Add P arrival line (now at x=0) ax.axvline(x=0, color='red', linewidth=1, linestyle='-', alpha=0.8) # Format station label with distance station_info = f"{tr.stats.network}.{tr.stats.station}.{tr.stats.location or ''}.{tr.stats.channel}" if hasattr(tr.stats, 'distance_km'): station_info += f" {tr.stats.distance_km:.1f} km" #adding event mag and region temporarily for debugging if hasattr(tr.stats, 'event_magnitude'): station_info += f", M{tr.stats.event_magnitude:.1f}" if hasattr(tr.stats, 'event_region'): station_info += f", {tr.stats.event_region}" if hasattr(tr.stats, 'filterband'): station_info += f", {tr.stats.filterband[0]}-{tr.stats.filterband[1]}Hz" # Position label inside plot ax.text(0.02, 0.95, station_info, transform=ax.transAxes, verticalalignment='top', horizontalalignment='left', fontsize=7, bbox=dict(facecolor='white', alpha=0.6, edgecolor='none', pad=1)) # Set consistent x-axis limits ax.set_xlim(-before_p, after_p) # Remove y-axis ticks and labels ax.set_yticks([]) # Only show x-axis labels for bottom subplot if i < num_traces - 1: ax.set_xticklabels([]) else: ax.set_xlabel('Time relative to P (seconds)') # Add subtle grid ax.grid(True, alpha=0.2) # Update box styling to show all borders for spine in ax.spines.values(): spine.set_visible(True) spine.set_linewidth(0.5) # Add padding to the plot ax.margins(x=0.05) # Increased padding to 5% on left and right # Adjust layout plt.subplots_adjust(left=0.1, right=0.9, top=0.97, bottom=0.1) return fig
[docs] def plot_station_view(self, station_code: str, stream: Stream, page: int, num_pages: int): """Plot station view with event information. Args: station_code (str): Code of the station to display. stream (Stream): ObsPy Stream object containing waveform data. page (int): Current page number for pagination. num_pages (int): Total number of pages. Returns: Figure: Matplotlib figure object containing the plot. """ if not stream: return # Sort traces by distance for tr in stream: if not hasattr(tr.stats, 'distance_km') or not tr.stats.distance_km: tr.stats.distance_km = 99999 stream = Stream(sorted(stream, key=lambda tr: tr.stats.distance_km)) #TODO users may prefer to sort by OT # Get current page's traces start_idx = page * self.filter_menu.display_limit end_idx = start_idx + self.filter_menu.display_limit current_stream = Stream(traces=stream.traces[start_idx:end_idx]) # Calculate standardized dimensions width, height = self._calculate_figure_dimensions(len(current_stream)) # Create figure with standardized dimensions fig = plt.figure(figsize=(width, height)) # Use GridSpec with standardized spacing gs = plt.GridSpec(len(current_stream), 1, height_ratios=[1] * len(current_stream), hspace=0.05, top=0.97, bottom=0.08, left=0.1, right=0.9) axes = [plt.subplot(gs[i]) for i in range(len(current_stream))] # Process each trace for i, tr in enumerate(current_stream): ax = axes[i] # Calculate and add an appropriate filter for plotting filter_min,filter_max = get_tele_filter(tr) if filter_min < filter_max: tr.stats.filterband = (filter_min,filter_max) if hasattr(tr.stats, 'p_arrival'): p_time = UTCDateTime(tr.stats.p_arrival) before_p = self.settings.event.before_p_sec after_p = self.settings.event.after_p_sec # Trim trace to window around P window_start = p_time - before_p window_end = p_time + after_p tr_windowed = tr.slice(window_start, window_end) # Pre-process and apply a bandpass filter if tr_windowed.stats.sampling_rate/2 > filter_min and filter_min<filter_max: tr_windowed.detrend() tr_windowed.taper(.005) tr_windowed.filter('bandpass',freqmin=filter_min,freqmax=filter_max,zerophase=True) # Calculate times relative to P arrival times = np.arange(tr_windowed.stats.npts) * tr_windowed.stats.delta relative_times = times - before_p # This makes P arrival at t=0 # Plot the trace ax.plot(relative_times, tr_windowed.data, '-', color=self._get_trace_color(tr), linewidth=0.8) # Add P arrival line (now at x=0) ax.axvline(x=0, color='red', linewidth=1, linestyle='-', alpha=0.8) # Format station label with distance, magnitude, and region station_info = f"{tr.stats.network}.{tr.stats.station}.{tr.stats.location or ''}.{tr.stats.channel}" event_info = [] # TODO: Add distance, magnitude, and region to station_info if hasattr(tr.stats, 'distance_km'): event_info.append(f"{tr.stats.distance_km:.1f} km") if hasattr(tr.stats, 'event_time'): event_info.append(f"OT:{str(tr.stats.event_time)[0:19]}") if hasattr(tr.stats, 'event_magnitude'): event_info.append(f"M{tr.stats.event_magnitude:.1f}") if hasattr(tr.stats, 'event_region'): event_info.append(tr.stats.event_region) if hasattr(tr.stats, 'filterband'): event_info.append(f"{tr.stats.filterband[0]}-{tr.stats.filterband[1]}Hz") # Combine all information with proper formatting label = f"{station_info} - {', '.join(event_info)}" # Position label inside plot ax.text(0.02, 0.95, label, transform=ax.transAxes, verticalalignment='top', horizontalalignment='left', fontsize=7, bbox=dict(facecolor='white', alpha=0.6, edgecolor='none', pad=1)) # Set consistent x-axis limits ax.set_xlim(-before_p, after_p) # Remove y-axis ticks and labels ax.set_yticks([]) # Only show x-axis labels for bottom subplot if i < len(current_stream) - 1: ax.set_xticklabels([]) else: ax.set_xlabel('Time relative to P (seconds)') # Add subtle grid ax.grid(True, alpha=0.2) # Update box styling to show all borders for spine in ax.spines.values(): spine.set_visible(True) spine.set_linewidth(0.5) # Add padding to the plot ax.margins(x=0.05) # Increased padding to 5% on left and right # Update title net, sta = station_code.split(".") #fig.suptitle(f"Station {station_code} - Multiple Events View", # fontsize=10, y=0.98) return fig
[docs] def render(self): """Render the waveform display interface. This method creates the main UI for waveform visualization, including: - View type selection (Event/Station view) - Waveform display - Missing data information """ view_type = st.radio( "Select View Type", ["Single Event - Multiple Stations", "Single Station - Multiple Events"], key="view_selector_waveform" ) if not self.stream: st.info("No waveforms to display. Use the 'Get Waveforms' button to retrieve waveforms.") return if view_type == "Single Event - Multiple Stations": events = self.settings.event.selected_catalogs if not events: st.warning("No events available.") return # a list of event resource ids existing_event_resource_ids = [eq.resource_id for eq in events] # a list of EQ resource_ids to confirm what data actually exists existing_data_resource_ids = list(set([tr.stats.resource_id for tr in self.stream])) # map the events.. hanging onto the original indexes so users can keep track of what's happening valid_events_with_indices = [(i, event) for i, event in enumerate(events) if hasattr(event, 'resource_id') and event.resource_id in existing_data_resource_ids] event_options = [ f"Event {orig_idx+1}: {event.origins[0].time} " f"M{event.magnitudes[0].mag if hasattr(event, 'magnitudes') and event.magnitudes else 0.99:.1f} " f"{event.extra.get('region', {}).get('value', 'Unknown Region')}" for orig_idx, event in valid_events_with_indices ] valid_events = [event for _, event in valid_events_with_indices] selected_event_idx = st.selectbox( "Select Event", range(len(event_options)), format_func=lambda x: event_options[x] ) selected_event = valid_events[selected_event_idx] event_stream = Stream([tr for tr in self.stream if tr.stats.resource_id == selected_event.resource_id.id]) filtered_stream = self.apply_filters(event_stream) if filtered_stream: # Calculate pagination num_pages = (len(filtered_stream) - 1) // self.filter_menu.display_limit + 1 page = st.sidebar.selectbox( "Page Navigation", range(1, num_pages + 1), key="event_view_pagination" ) - 1 fig = self.plot_event_view( selected_event, filtered_stream, page, num_pages ) if fig: st.session_state.current_figure = fig st.pyplot(fig) else: st.warning("No waveforms match the current filter criteria.") else: # Single Station - Multiple Events view if not self.stream: st.warning("No traces available.") return # Get unique stations from all traces filtered_stream = self.apply_filters(self.stream) stations = set([f"{tr.stats.network}.{tr.stats.station}" for tr in filtered_stream]) if not stations: st.warning("No stations match the current filter criteria.") return # may have to check if there exists data for said station (TODO) station_options = sorted(list(stations)) selected_station = st.selectbox( "Select Station", station_options ) if selected_station: net, sta = selected_station.split(".") # select for pertinent station station_stream = filtered_stream.select(network=net,station=sta) if station_stream: # Calculate pagination num_pages = (len(station_stream) - 1) // self.filter_menu.display_limit + 1 page = st.sidebar.selectbox( "Page Navigation", range(1, num_pages + 1), key="station_view_pagination" ) - 1 # Use plot_station_view try: fig = self.plot_station_view(selected_station, station_stream, page, num_pages) except Exception as e: print(f"waveform.plot_station_view issue:\n {e}") if fig: st.session_state.current_figure = fig st.pyplot(fig) else: st.warning("No waveforms available for the selected station.") # Create missing data display before checking stream missing_data_display = MissingDataDisplay( self.stream, self.missing_data, self.settings ) missing_data_display.render()
[docs] class WaveformComponents: settings: SeismoLoaderSettings filter_menu: WaveformFilterMenu waveform_display: WaveformDisplay continuous_components: ContinuousComponents console: ConsoleDisplay def __init__(self, settings: SeismoLoaderSettings): self.settings = settings self.filter_menu = WaveformFilterMenu(settings) self.waveform_display = WaveformDisplay(settings, self.filter_menu) self.continuous_components = ContinuousComponents(settings) self.console = ConsoleDisplay() # Initialize console with logs from session state if they exist if "log_entries" in st.session_state and st.session_state["log_entries"]: self.console.accumulated_output = st.session_state["log_entries"] # Pass console to WaveformDisplay self.waveform_display.console = self.console # Initialize session state required_states = { "is_downloading": False, "query_done": False, "polling_active": False, "query_thread": None, "trigger_rerun": False, "log_entries": [] } for key, val in required_states.items(): if key not in st.session_state: st.session_state[key] = val
[docs] def render_polling_ui(self): """ Handles UI updates while monitoring background thread status """ if st.session_state.get("is_downloading", False): query_thread = st.session_state.get("query_thread") # Process any new log entries from the queue new_logs = False while not log_queue.empty(): try: log_entry = log_queue.get_nowait() if not self.console.accumulated_output: self.console.accumulated_output = [] self.console.accumulated_output.append(log_entry) new_logs = True except queue.Empty: break # Save logs to session state if updated if new_logs: st.session_state["log_entries"] = self.console.accumulated_output # Trigger rerun to update the UI with new logs st.rerun() if query_thread and not query_thread.is_alive(): try: query_thread.join() except Exception as e: st.error(f"Error in background thread: {e}") # Add error to console output if not self.console.accumulated_output: self.console.accumulated_output = [] self.console.accumulated_output.append(f"Error: {str(e)}") st.session_state["log_entries"] = self.console.accumulated_output st.session_state.update({ "is_downloading": False, "query_done": True, "query_thread": None, "polling_active": False }) st.rerun() # Always trigger a rerun while polling is active to check for new logs if st.session_state.get("polling_active"): time.sleep(0.2) # Shorter pause for more frequent updates st.rerun()
[docs] def render(self): if self.settings.selected_workflow == WorkflowType.CONTINUOUS: self.continuous_components.render() return # Initialize tab selection in session state if not exists if "active_tab" not in st.session_state: st.session_state["active_tab"] = 0 # Default to waveform tab # Auto-switch to log tab during download if new logs are available if st.session_state.get("is_downloading", False) and log_queue.qsize() > 0: st.session_state["active_tab"] = 0 # Keep on waveform tab to show real-time logs # Create tabs for Waveform and Log views tab_names = ["📊 Waveform View", "📝 Log View"] waveform_tab, log_tab = st.tabs(tab_names) # Get the current stream and update available channels before rendering filter menu current_stream = None if self.waveform_display.stream: # The stream can be either a list of traces or a Stream object # We need to pass the actual stream data to update_available_channels current_stream = self.waveform_display.stream # Update available channels with the current stream self.filter_menu.update_available_channels(current_stream) # Always render filter menu (sidebar) first self.filter_menu.render(current_stream) # Handle content based on active tab with waveform_tab: self._render_waveform_view() with log_tab: # If we're switching to log tab and download is complete, # make sure all logs are transferred from queue to accumulated_output if not st.session_state.get("is_downloading", False): # Process any remaining logs in the queue while not log_queue.empty(): try: log_entry = log_queue.get_nowait() if not self.console.accumulated_output: self.console.accumulated_output = [] self.console.accumulated_output.append(log_entry) except queue.Empty: break # Save to session state if self.console.accumulated_output: st.session_state["log_entries"] = self.console.accumulated_output self._render_log_view()
def _render_waveform_view(self): st.title("Event Arrivals") # Create three columns for the controls col1, col2, col3 = st.columns(3) # Force Re-download toggle in first column with col1: self.settings.waveform.force_redownload = st.toggle( "Force Re-download", value=self.settings.waveform.force_redownload, help="If turned off, the app will try to avoid " "downloading data that are already available locally." " If flagged, it will redownload the data again." ) # Get Waveforms button in second column with col2: get_waveforms_button = st.button( "Get Waveforms", key="get_waveforms", disabled=st.session_state.get("is_downloading", False), use_container_width=True ) # Cancel Download button in third column with col3: if st.button("Cancel Download", key="cancel_download", disabled=not st.session_state.get("is_downloading", False), use_container_width=True): stop_event.set() # Signal cancellation st.warning("Cancelling download...") st.session_state.update({ "download_cancelled": True # Set cancellation flag }) st.rerun() # Download status indicator status_container = st.empty() # Show appropriate status message if get_waveforms_button: status_container.info("Starting waveform download...") self.waveform_display.retrieve_waveforms() elif st.session_state.get("is_downloading"): # status_container.info("Downloading waveforms... (this may take several minutes)") st.spinner("Downloading waveforms... (this may take several minutes)") # Display real-time logs in the waveform view during download log_container = st.empty() # Process any new log entries from the queue new_logs = False while not log_queue.empty(): try: log_entry = log_queue.get_nowait() if not self.console.accumulated_output: self.console.accumulated_output = [] self.console.accumulated_output.append(log_entry) new_logs = True except queue.Empty: break # Save logs to session state if updated if new_logs or self.console.accumulated_output: st.session_state["log_entries"] = self.console.accumulated_output # Display logs in the waveform view if self.console.accumulated_output: # Add the initial header line if it's not already there if not any("Running run_event" in line for line in self.console.accumulated_output): self.console.accumulated_output.insert(0, "Running run_event\n-----------------") st.session_state["log_entries"] = self.console.accumulated_output raw_content = "".join(self.console.accumulated_output) escaped_content = escape(raw_content) log_text = ( '<div class="terminal" id="log-terminal" style="max-height: 700px;">' f'<pre style="margin: 0; white-space: pre; tab-size: 4;">{escaped_content}</pre>' '</div>' '<script>' 'if (window.terminal_scroll === undefined) {' ' window.terminal_scroll = function() {' ' var terminalDiv = document.getElementById("log-terminal");' ' if (terminalDiv) {' ' terminalDiv.scrollTop = terminalDiv.scrollHeight;' ' }' ' };' '}' 'window.terminal_scroll();' '</script>' ) log_container.markdown(log_text, unsafe_allow_html=True) self.render_polling_ui() elif st.session_state.get("query_done") and self.waveform_display.stream: status_container.success(f"Successfully retrieved waveforms for {len(self.waveform_display.stream)} channels.") elif st.session_state.get("query_done"): if st.session_state.get("download_cancelled", False): status_container.warning("Waveform download was cancelled by user.") # Reset the flag after displaying st.session_state["download_cancelled"] = False else: status_container.warning("No waveforms retrieved. Please check your selection criteria and log view.") # Display waveforms if they exist if self.waveform_display.stream: self.waveform_display.render() # Add download button at the bottom of the sidebar with st.sidebar: st.markdown("---") if st.session_state.get("current_figure") is not None: import io buf = io.BytesIO() st.session_state.current_figure.savefig(buf, format='png', dpi=300, bbox_inches='tight') buf.seek(0) st.download_button( label="Download PNG", data=buf, file_name="waveform_plot.png", mime="image/png", use_container_width=True ) else: st.button("Download PNG", disabled=True, use_container_width=True) def _render_log_view(self): st.title("Waveform Retrieval Logs") self.console._init_terminal_style() # Initialize terminal styling # Process any pending log entries from the queue logs_updated = False while not log_queue.empty(): try: log_entry = log_queue.get_nowait() if not self.console.accumulated_output: self.console.accumulated_output = [] self.console.accumulated_output.append(log_entry) logs_updated = True except queue.Empty: break # Save logs to session state if updated if logs_updated: st.session_state["log_entries"] = self.console.accumulated_output if self.console.accumulated_output: # Add the initial header line if it's not already there if not any("Running run_event" in line for line in self.console.accumulated_output): self.console.accumulated_output.insert(0, "Running run_event\n-----------------") st.session_state["log_entries"] = self.console.accumulated_output raw_content = "".join(self.console.accumulated_output) escaped_content = escape(raw_content) log_text = ( '<div class="terminal" id="log-terminal">' f'<pre style="margin: 0; white-space: pre; tab-size: 4;">{escaped_content}</pre>' '</div>' '<script>' 'if (window.terminal_scroll === undefined) {' ' window.terminal_scroll = function() {' ' var terminalDiv = document.getElementById("log-terminal");' ' if (terminalDiv) {' ' terminalDiv.scrollTop = terminalDiv.scrollHeight;' ' }' ' };' '}' 'window.terminal_scroll();' '</script>' ) st.markdown(log_text, unsafe_allow_html=True) else: st.info("Perform a waveform download first :)")
[docs] class MissingDataDisplay: """A component for displaying information about missing waveform data. This class provides a user interface for showing which events or stations have missing data and what specific channels are missing. Attributes: stream (List[Stream]): List of waveform streams. missing_data (Dict): Dictionary tracking missing data. settings (SeismoLoaderSettings): Configuration settings for seismic data processing. """ def __init__(self, stream: List[Stream], missing_data: Dict[str, Union[List[str], str]], settings: SeismoLoaderSettings): """Initialize the MissingDataDisplay component. Args: stream (List[Stream]): List of waveform streams. missing_data (Dict[str, Union[List[str], str]]): Dictionary mapping event IDs to missing data information. settings (SeismoLoaderSettings): Configuration settings for seismic data processing. """ self.stream = stream #is this needed? i think we can drop it TODO self.missing_data = missing_data self.settings = settings def _format_event_time(self, event) -> str: """Format event time in a readable way. Args: event: Event object containing event information. Returns: str: Formatted event time string. """ return event.origins[0].time.strftime('%Y-%m-%d %H:%M:%S') def _get_missing_events(self): """Identify events with no data and their missing channels. Returns: List[Dict]: List of dictionaries containing information about events with missing data, including event ID, time, magnitude, region, and missing channels. """ missing_events = [] # sort events by time try: catalog = self.settings.event.selected_catalogs.copy() #need copy? catalog.events.sort(key=lambda x: getattr(x.origins[0], 'time', UTCDateTime(0)) if x.origins else UTCDateTime(0)) except Exception as e: print("catalog sort problem",e) for event in catalog: resource_id = str(event.resource_id) # Create a string for NSLCs which should have been downloaded (e.g. within search radius) but weren't (e.g. missing on server) try: if resource_id not in self.missing_data.keys(): continue results = [] for station_key, value in self.missing_data[resource_id].items(): if value == "ALL": results.append(f"{station_key}.*") # Indicate all channels missing elif value == '': continue elif isinstance(value, list): if value: # If list not empty results.extend(value) # Add all missing channels if results: missing_data_str = ','.join(results) else: missing_data_str = None except Exception as e: missing_data_str = None print("DEBUG missing data dict issue: ",e) if missing_data_str: # Combine event ot, mag, region into one column event_str = f"{self._format_event_time(event)}, M{event.magnitudes[0].mag:.1f}, {event.extra.get('region', {}).get('value', 'Unknown Region')}" # Event completely missing missing_events.append({ 'EQ Resource ID': resource_id, 'Event': event_str, 'Missing Data': missing_data_str }) return missing_events
[docs] def render(self): """Render the missing data display interface. This method creates a table showing events with missing data, including: - Event information (time, magnitude, region) - Missing channel information - Dynamic height adjustment based on number of entries """ missing_events = self._get_missing_events() if missing_events: st.warning("⚠️ Events with Missing Data:") # Create DataFrame from missing events df = pd.DataFrame(missing_events) # Calculate dynamic height based on number of rows height = len(df) * 35 + 40 # Same formula as distance display # Display the DataFrame st.dataframe( df, use_container_width=True, height=height, hide_index=True )
if st.session_state.get("trigger_rerun", False): st.session_state["trigger_rerun"] = False # Reset flag to prevent infinite loops st.rerun() # 🔹 Force UI update