class StreamlitPatcher:
    """class to patch streamlit functions for displaying content in jupyter notebooks

    ``MAPPING`` (method name -> wrapper factory) and ``_wrap`` are attached to
    this class elsewhere in the module.
    """

    def __init__(self):
        # becomes True once `jupyter()` has applied every entry of MAPPING
        self.is_registered: bool = False
        # names of the `st` attributes this patcher has replaced so far
        self.registered_methods: tp.Set[str] = set()

    def jupyter(self):
        """patches streamlit methods to display content in jupyter notebooks

        Every ``(method_name, wrapper)`` pair of ``self.MAPPING`` is handed to
        ``self._wrap``. Repeated calls on the same patcher are a no-op:
        ``_wrap`` wraps whatever ``st.<method_name>`` currently is, so running
        the loop again would stack a second wrapper on top of the first.
        """
        if self.is_registered:
            return
        # patch streamlit methods from MAPPING property dict
        for method_name, wrapper in self.MAPPING.items():
            self._wrap(method_name, wrapper)
        self.is_registered = True

    @staticmethod
    def _get_streamlit_methods():
        """get all streamlit methods (public attribute names of `st`)"""
        return [attr for attr in dir(st) if not attr.startswith("_")]
core
tqdm patch
Calling tqdm as tqdm.notebook or stqdm depending on environment
StreamlitPatcher
StreamlitPatcher ()
class to patch streamlit functions for displaying content in jupyter notebooks
@patch_to(StreamlitPatcher, cls_method=False)
def _wrap(
    cls,
    method_name: str,
    wrapper: tp.Callable,
) -> None:
    """make a streamlit method jupyter friendly

    Replaces ``st.<method_name>`` with ``wrapper(original)`` and records the
    name in ``registered_methods``. Outside IPython nothing is changed.

    Parameters
    ----------
    method_name : str
        which method to jupyterify
    wrapper : tp.Callable
        wrapper function to use; called with the current streamlit method
        and expected to return its replacement
    """
    # NOTE: patched with cls_method=False, so despite its name `cls` is the
    # StreamlitPatcher *instance* (registered_methods is set in __init__).
    if IN_IPYTHON:  # only patch if in jupyter
        trg = getattr(st, method_name)  # get the streamlit method
        setattr(st, method_name, wrapper(trg))  # patch the method
        # add to registered methods
        cls.registered_methods.add(method_name)
# a fresh patcher has not applied MAPPING yet
sp = StreamlitPatcher()
assert not sp.is_registered, "StreamlitPatcher is already registered"
Modifying streamlit
We modify streamlit methods by passing them through a decorator. The decorator checks whether we are in a jupyter notebook; if so, it takes the input and displays it in the notebook.
Otherwise, it uses the original streamlit method.
st.write
"write", _st_write)
sp._wrap(
with capture_output() as cap:
"hello")
st.write(= cap._outputs[0]["data"]
got
= {
expected "text/plain": "<IPython.core.display.Markdown object>",
"text/markdown": "hello",
}assert got == expected, "check that the output is correct"
"hello") st.write(
hello
"This is **bold** text in markdown") st.write(
This is bold text in markdown
# pandas is optional: import it inside the `try` so that a missing install is
# really caught by `except ImportError` (touching an unimported `pd` would
# raise NameError instead and escape the handler)
try:
    import pandas as pd

    df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    st.write(df)
except ImportError:
    logger.warning("Pandas not installed, skipping test")
|   | a | b |
|---|---|---|
| 0 | 1 | 4 |
| 1 | 2 | 5 |
| 2 | 3 | 6 |
# so far only "write" has been passed to `sp._wrap` in this notebook
assert sp.registered_methods == {"write"}, "check that the method is registered"
patching headings
st.title
st.header
st.subheader
sp = StreamlitPatcher()
sp._wrap("title", functools.partial(_st_heading, tag="#"))
sp._wrap("header", functools.partial(_st_heading, tag="##"))
sp._wrap("subheader", functools.partial(_st_heading, tag="###"))

# each heading function should render its text as markdown prefixed by its tag
for heading_fn, tag in ((st.title, "#"), (st.header, "##"), (st.subheader, "###")):
    with capture_output() as cap:
        heading_fn("foo")
    got = cap._outputs[0]["data"]["text/markdown"]
    test_eq(got, f"{tag} foo")

# these should fail
test_fail(lambda: st.title(df), contains="Unsupported type")
test_fail(lambda: st.header(df), contains="Unsupported type")
test_fail(lambda: st.subheader(df), contains="Unsupported type")
test_fail(lambda: st.subheader(1), contains="Unsupported type")
Patch some methods to simply display the input in jupyter
"markdown", functools.partial(_st_type_check, allowed_types=str))
sp._wrap(
lambda: st.markdown(df), contains="Unsupported type")
test_fail("This is **bold** text in markdown") st.markdown(
This is bold text in markdown
"dataframe", functools.partial(_st_type_check, allowed_types=pd.DataFrame))
sp._wrap(lambda: st.dataframe("foo"), contains="Unsupported type")
test_fail( st.dataframe(df)
|   | a | b |
|---|---|---|
| 0 | 1 | 4 |
| 1 | 2 | 5 |
| 2 | 3 | 6 |
st.code
st.code("""
def foo():
print('hello')
"""
)
def foo():
print('hello')
"grep -r 'foo' .", language=None) st.code(
grep -r 'foo' .
st.text
st.latex
"latex", _st_latex) # |hide_line
sp._wrap(r"E=mc^2") st.latex(
\[\begin{equation}E=mc^2\end{equation}\]
st.latex(r"""a + ar + a r^2 + a r^3 + \cdots + a r^{n-1} =
\sum_{k=0}^{n-1} ar^k =
a \left(\frac{1-r^{n}}{1-r}\right)
"""
)
\[\begin{equation}a + ar + a r^2 + a r^3 + \cdots + a r^{n-1} = \sum_{k=0}^{n-1} ar^k = a \left(\frac{1-r^{n}}{1-r}\right) \end{equation}\]
st.json
Testing output of `st.json` with a dict
= {"foo": "bar", "baz": [1, 2, 3]}
body = '```json\n{\n "foo": "bar",\n "baz": [\n 1,\n 2,\n 3\n ]\n}\n```' # |hide_line
expected # |hide_line
test_md_output(st.json, expected, body) st.json(body)
{
"foo": "bar",
"baz": [
1,
2,
3
]
}
= {"foo": "bar", "baz": [1, 2, 3]}
body = '```json\n{"foo": "bar", "baz": [1, 2, 3]}\n```' # |hide_line
expected =False) # |hide_line
test_md_output(st.json, expected, body, expanded=False) st.json(body, expanded
{"foo": "bar", "baz": [1, 2, 3]}
Testing output of `st.json` with a str
body = '{"foo": "bar", "baz": [1,2,3]}'
# NOTE(review): whitespace inside this literal may have been collapsed during
# extraction; compare against the actual pretty-printed st.json output
expected = '```json\n{\n "foo": "bar",\n "baz": [\n 1,\n 2,\n 3\n ]\n}\n```'  # |hide_line
test_md_output(st.json, expected, body)  # |hide_line
st.json(body)
{
"foo": "bar",
"baz": [
1,
2,
3
]
}
body = '{"foo": "bar", "baz": [1,2,3]}'
# with expanded=False a str body is shown exactly as given
expected = '```json\n{"foo": "bar", "baz": [1,2,3]}\n```'  # |hide_line
test_md_output(st.json, expected, body, expanded=False)  # |hide_line
st.json(body, expanded=False)
{"foo": "bar", "baz": [1,2,3]}
st.cache, st.cache_data, st.cache_resource
The streamlit `cache` method (and its `cache_data` / `cache_resource` variants) is used to cache the output of a function. This is useful for functions that take a long time to run, when we want to avoid rerunning them every time the app runs.
In a jupyter notebook we can't use the streamlit `cache` method, so we replace it with a dummy method that does nothing.
"cache", _dummy_wrapper_noop)
sp._wrap("cache_data", _dummy_wrapper_noop)
sp._wrap("cache_resource", _dummy_wrapper_noop) sp._wrap(
# verify that during patching we didn't change the name or docstring
assert st.cache.__name__ == "cache"
assert "@st.cache" in tp.cast(
str, st.cache.__doc__
"check that the docstring is correct" ),
# test caching
@st.cache_data()
def get_data():
    st.write("Getting data...")
    # simulate slow work
    for _ in tqdm(range(5)):
        time.sleep(0.1)
    return pd.DataFrame({"c": [7, 8, 9], "d": [10, 11, 12]})


df = get_data()
st.write(df)
Getting data…
|   | c | d  |
|---|---|----|
| 0 | 7 | 10 |
| 1 | 8 | 11 |
| 2 | 9 | 12 |
# test that the cache in jupyter does not affect get_data
df = get_data()
with capture_output() as cap:
    st.write(df)
got = cap._outputs[0]["data"]
# NOTE(review): whitespace inside these literals may have been collapsed
# during extraction; compare against the real DataFrame repr / _repr_html_
expected = {
    "text/plain": " c d\n0 7 10\n1 8 11\n2 9 12",
    "text/html": '<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border="1" class="dataframe">\n <thead>\n <tr style="text-align: right;">\n <th></th>\n <th>c</th>\n <th>d</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>7</td>\n <td>10</td>\n </tr>\n <tr>\n <th>1</th>\n <td>8</td>\n <td>11</td>\n </tr>\n <tr>\n <th>2</th>\n <td>9</td>\n <td>12</td>\n </tr>\n </tbody>\n</table>\n</div>',
}
assert got == expected, "check that the output is correct"
Getting data…
# test caching
@st.cache_resource(ttl=3600)
def get_resource():
    st.write("Getting resource...")
    # simulate slow work
    for _ in tqdm(range(5)):
        time.sleep(0.1)
    return {
        "foo": "bar",
        "baz": [1, 2, 3],
        "qux": {"a": 1, "b": 2, "c": 3},
    }


expected = {
    "foo": "bar",
    "baz": [1, 2, 3],
    "qux": {"a": 1, "b": 2, "c": 3},
}
got = get_resource()
assert got == expected, "check that the output is correct"
Getting resource…
# test that the cache in jupyter does not affect get_resource
records = get_resource()
with capture_output() as cap:
    st.write(records)
got = cap._outputs[0]["data"]
expected = {
    "text/plain": "{'foo': 'bar', 'baz': [1, 2, 3], 'qux': {'a': 1, 'b': 2, 'c': 3}}"
}
assert got == expected, "check that the output is correct"
Getting resource…
st.expander
Note that this is an exception to the usual wrapper logic.
Since `st.expander` is used as a context manager, we replace it with a dummy class that displays the input in jupyter.
"expander", _st_expander) sp._wrap(
with st.expander("Expand me!", expanded=False):
st.markdown("""
The **#30DaysOfStreamlit** is a coding challenge designed to help you get started in building Streamlit apps.
Particularly, you'll be able to:
- Set up a coding environment for building Streamlit apps
- Build your first Streamlit app
- Learn about all the awesome input/output widgets to use for your Streamlit app
"""
)
"**More text, we can expand as many streamlit elements as we want**") st.write(
expander starts: Expand me!
The #30DaysOfStreamlit is a coding challenge designed to help you get started in building Streamlit apps.
Particularly, you’ll be able to: - Set up a coding environment for building Streamlit apps - Build your first Streamlit app - Learn about all the awesome input/output widgets to use for your Streamlit app
More text, we can expand as many streamlit elements as we want
expander ends
st.text_input
"text_input", _st_text_input)
sp._wrap("text_area", _st_text_input) sp._wrap(
= st.text_input("String:", "default text")
text text
'default text'
text = st.text_area("Input:", "foo bar")
text
'foo bar'
st.date_input
"date_input", _st_date_input) sp._wrap(
⚠️ Note the following limitation: when using this in jupyter, changing the date on your widget will not affect the `date` variable.
Streamlit behavior remains unchanged, though.
# an ISO date string is accepted and returned as a datetime.date
date = st.date_input("Pick a date", value="2022-12-13")
assert date == datetime(2022, 12, 13).date()
st.checkbox
"checkbox", _st_checkbox) sp._wrap(
= st.checkbox("Show code")
show_code assert show_code
= st.checkbox("Show code", value=False)
show_code assert not show_code
st.radio and st.selectbox
sp._wrap("radio", functools.partial(_st_single_choice, jupyter_widget=widgets.RadioButtons)
)
sp._wrap("selectbox", functools.partial(_st_single_choice, jupyter_widget=widgets.Dropdown)
)
"Pick", options=["foo", "bar"], index=1, key="radio") st.radio(
'bar'
"Choose", options=["foo", "bar"]) st.selectbox(
'foo'
st.multiselect
"multiselect", _st_multiselect) sp._wrap(
"Multiselect: ", options=["python", "golang", "julia", "rust"]) st.multiselect(
()
st.multiselect("Multiselect with defaults: ",
=["nbdev", "streamlit", "jupyter", "fastcore"],
options=["jupyter", "streamlit"],
default )
('jupyter', 'streamlit')
st.metric
"metric", _st_metric) sp._wrap(
# test that we don't allow invalid values for delta_color and label_visibility
test_fail(lambda: st.metric(
"Speed", 300, 210, delta_color="FOOBAR", label_visibility="hidden"
),="delta_color",
contains
)
test_fail(lambda: st.metric(
"Speed", 300, 210, delta_color="normal", label_visibility="FOOBAR"
),="label_visibility",
contains
)
# display a metric
"Speed", 300, 210, delta_color="normal", label_visibility="hidden") st.metric(
2023-03-06 17:34:09.265 WARNING __main__: `delta_color` argument is not supported in Jupyter notebooks, but will be applied in Streamlit
2023-03-06 17:34:09.266 WARNING __main__: `label_visibility` argument is not supported in Jupyter notebooks, but will be applied in Streamlit
2023-03-06 17:34:09.267 WARNING __main__: plotly is not installed, falling back to default st.metric implementation
To use plotly, run `pip install plotly`
st.metric widget (this will work as expected in streamlit)
st.columns
TODO: - [ ] add support for `st.columns` in jupyter
# logger.warning("Not implemented yet")
StreamlitPatcher.MAPPING
`MAPPING` is a dictionary that maps each streamlit method name to the wrapper we want to use instead.
It is used when `StreamlitPatcher.jupyter()` is called.
# a brand-new patcher has not wrapped anything yet
sp = StreamlitPatcher()
assert not sp.registered_methods, "registered methods should be empty at this point"
StreamlitPatcher.jupyter
StreamlitPatcher.jupyter ()
patches streamlit methods to display content in jupyter notebooks
# apply every wrapper from MAPPING and list the patched method names
sp.jupyter()
sp.registered_methods
{'cache',
'cache_data',
'cache_resource',
'caption',
'checkbox',
'code',
'dataframe',
'date_input',
'expander',
'header',
'json',
'latex',
'markdown',
'metric',
'multiselect',
'radio',
'selectbox',
'subheader',
'text',
'text_area',
'text_input',
'title',
'write'}