# Source code for stream.visuals.plot_NAM
import numpy as np
import plotly.graph_objects as go
import torch
from dash import Dash, Input, Output, dcc, html
def _shape_function_curve(feature_nn, feature_values, *, max_categories=10, n_points=250):
    """Evaluate ``feature_nn`` on a grid of inputs derived from ``feature_values``.

    Returns ``(x_values, predictions)`` as 1-D NumPy arrays. The module's
    train/eval mode is restored afterwards, so evaluation has no lasting
    side effect on the network.
    """
    unique_values = np.unique(np.asarray(feature_values))
    if len(unique_values) > max_categories:
        # Many distinct values: treat as continuous and sample an evenly spaced
        # grid over the observed range. nan-aware so missing values do not turn
        # the whole grid into NaN.
        x_values = np.linspace(
            np.nanmin(unique_values), np.nanmax(unique_values), n_points
        )
    else:
        # Few distinct values: treat as categorical/binary and evaluate only
        # at the values that actually occur.
        x_values = unique_values

    # Build the (k, 1) input on the same device as the network's parameters
    # (CPU when it has none), using the default float dtype like torch.Tensor.
    first_param = next(feature_nn.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    x_tensor = torch.as_tensor(
        x_values, dtype=torch.get_default_dtype(), device=device
    ).unsqueeze(1)

    was_training = feature_nn.training
    feature_nn.eval()
    try:
        with torch.no_grad():
            # .cpu() so .numpy() also works when the model lives on an accelerator.
            predictions = feature_nn(x_tensor).cpu().numpy().flatten()
    finally:
        # Plotting must not leave the caller's module stuck in eval mode.
        feature_nn.train(was_training)
    return x_values, predictions


def plot_with_plotly(
    feature_nn,
    feature_name,
    feature_values,
    y_min,
    y_max,
    *,
    max_categories=10,
    n_points=250,
    apply_y_range=False,
):
    """Plot the learned shape function of a single feature network.

    The network is evaluated over a grid of input values and the resulting
    curve is drawn as a red line in a Plotly figure.

    Args:
        feature_nn: ``torch.nn.Module`` fed a ``(k, 1)`` tensor of feature
            values; its output is flattened into ``k`` y-values.
        feature_name: Label used for the trace name, the y-axis title and the
            figure title.
        feature_values: Observed values of the feature (array-like, e.g. a
            pandas Series). Only used to choose the evaluation grid.
        y_min: Lower y-axis limit; ignored unless ``apply_y_range`` is true.
        y_max: Upper y-axis limit; ignored unless ``apply_y_range`` is true.
        max_categories: Features with more than this many distinct values are
            treated as continuous; otherwise they are evaluated only at their
            distinct values.
        n_points: Number of grid points used for continuous features.
        apply_y_range: When true, fix the y-axis to ``[y_min, y_max]``. Off by
            default, matching the previous behaviour where the range was
            disabled.

    Returns:
        ``plotly.graph_objects.Figure`` with a single line trace.
    """
    x_values, predictions = _shape_function_curve(
        feature_nn,
        feature_values,
        max_categories=max_categories,
        n_points=n_points,
    )

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=x_values,
            y=predictions,
            mode="lines",
            name=feature_name,
            line=dict(color="red"),
        )
    )
    fig.update_layout(
        yaxis_title=feature_name,
        title=f"Feature: {feature_name}",
        xaxis_title="Input Range",
    )
    if apply_y_range:
        fig.update_yaxes(range=[y_min, y_max])
    return fig
def plot_downstream_model(downstream_model, *, debug=True, **run_kwargs):
    """Launch a Dash app for browsing the per-feature networks of a model.

    A dropdown lists every column of ``downstream_model.combined_data``
    except the last one, which is treated as the target. Choosing a column
    plots the network at the same position in
    ``downstream_model.model.feature_nns`` via ``plot_with_plotly``. The call
    blocks while the Dash server runs.

    Args:
        downstream_model: Object exposing ``combined_data`` (a DataFrame whose
            last column is the target) and ``model.feature_nns`` (indexable,
            presumably one network per feature column in column order —
            confirm against the model definition).
        debug: Forwarded to the Dash server (``True`` by default, as before).
        **run_kwargs: Extra keyword arguments for the server call, e.g.
            ``port`` or ``host``.

    Raises:
        ValueError: If ``combined_data`` has no feature columns besides the
            target.
    """
    data = downstream_model.combined_data
    # Every column except the last is a feature; the last column is always
    # dropped because it is assumed to hold the target.
    feature_names = data.columns[:-1].tolist()
    if not feature_names:
        raise ValueError(
            "combined_data must contain at least one feature column besides the target"
        )

    # Global target range, passed through to the plotting function.
    target = data.iloc[:, -1]
    y_min = np.min(target)
    y_max = np.max(target)

    app = Dash(__name__)
    app.layout = html.Div(
        [
            html.H1("Feature-specific Neural Network Functions"),
            dcc.Dropdown(
                id="feature-dropdown",
                options=[{"label": name, "value": name} for name in feature_names],
                value=feature_names[0],  # Default value
                # Dropdowns are clearable by default; clearing would send None
                # to the callback, which is not a valid feature name.
                clearable=False,
            ),
            dcc.Graph(id="feature-graph"),
        ]
    )

    @app.callback(
        Output("feature-graph", "figure"), [Input("feature-dropdown", "value")]
    )
    def update_graph(selected_feature):
        # Defensive: an unknown/empty selection yields an empty figure instead
        # of a ValueError from list.index.
        if selected_feature not in feature_names:
            return go.Figure()
        feature_index = feature_names.index(selected_feature)
        feature_nn = downstream_model.model.feature_nns[feature_index]
        # Observed values of the selected feature determine the plotting grid.
        feature_values = data[selected_feature]
        return plot_with_plotly(
            feature_nn, selected_feature, feature_values, y_min, y_max
        )

    # ``Dash.run_server`` is deprecated and no longer usable in Dash 3; prefer
    # ``Dash.run`` and fall back only for old releases that lack it. The check
    # is on the class so Dash's instance-level attribute hooks are not hit.
    launch = app.run if hasattr(type(app), "run") else app.run_server
    launch(debug=debug, **run_kwargs)