core

modifying streamlit methods to play nice with jupyter

tqdm patch

Call tqdm as tqdm.notebook or stqdm, depending on the environment.
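
A minimal sketch of how this selection could look; the IN_IPYTHON flag defined here mirrors the module-level flag used by the patcher below (the actual module code may differ):

try:
    get_ipython()  # defined only when running inside IPython/jupyter
    IN_IPYTHON = True
except NameError:
    IN_IPYTHON = False

if IN_IPYTHON:
    from tqdm.notebook import tqdm  # widget-based progress bar for jupyter
else:
    from stqdm import stqdm as tqdm  # streamlit-friendly progress bar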

class StreamlitPatcher:
    """class to patch streamlit functions for displaying content in jupyter notebooks"""

    def __init__(self):
        self.is_registered: bool = False
        self.registered_methods: tp.Set[str] = set()

    def jupyter(self):
        """patches streamlit methods to display content in jupyter notebooks"""
        # patch streamlit methods from MAPPING property dict
        for method_name, wrapper in self.MAPPING.items():
            self._wrap(method_name, wrapper)

        self.is_registered = True

    @staticmethod
    def _get_streamlit_methods():
        """get all streamlit methods"""
        return [attr for attr in dir(st) if not attr.startswith("_")]

source

StreamlitPatcher

 StreamlitPatcher ()

class to patch streamlit functions for displaying content in jupyter notebooks
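
Typical usage (shown in full at the bottom of this page):

sp = StreamlitPatcher()
sp.jupyter()  # patches every streamlit method listed in StreamlitPatcher.MAPPING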

@patch_to(StreamlitPatcher, cls_method=False)
def _wrap(
    cls,
    method_name: str,
    wrapper: tp.Callable,
) -> None:
    """make a streamlit method jupyter friendly

    Parameters
    ----------
    method_name : str
        which method to jupyterify
    wrapper : tp.Callable
        wrapper function to use
    """
    if IN_IPYTHON:  # only patch if in jupyter
        trg = getattr(st, method_name)  # get the streamlit method
        setattr(st, method_name, wrapper(trg))  # patch the method
        cls.registered_methods.add(method_name)  # add to registered methods
sp = StreamlitPatcher()

assert not sp.is_registered, "StreamlitPatcher is already registered"

Modifying streamlit

We modify streamlit methods by passing them through a decorator. The decorator checks whether we are running in a jupyter notebook; if so, it takes the input and displays it in the notebook.

Otherwise it falls back to the original streamlit method.
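
For illustration, a minimal sketch of such a wrapper (an assumption; the real _st_write used below may differ):

import functools
import typing as tp

from IPython.display import Markdown, display


def _st_write_sketch(st_method: tp.Callable) -> tp.Callable:
    """sketch of a jupyter-friendly st.write replacement"""

    @functools.wraps(st_method)
    def wrapped(*args, **kwargs):
        if IN_IPYTHON:
            for arg in args:
                # strings render as markdown, everything else via its rich repr
                display(Markdown(arg) if isinstance(arg, str) else arg)
        else:
            # outside jupyter, defer to the original streamlit method
            return st_method(*args, **kwargs)

    return wrapped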

st.write

sp._wrap("write", _st_write)

with capture_output() as cap:
    st.write("hello")
    got = cap._outputs[0]["data"]

expected = {
    "text/plain": "<IPython.core.display.Markdown object>",
    "text/markdown": "hello",
}
assert got == expected, "check that the output is correct"

st.write("hello")

hello

st.write("This is **bold** text in markdown")

This is bold text in markdown

try:
    df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    st.write(df)
except ImportError:
    logger.warning("Pandas not installed, skipping test")
   a  b
0  1  4
1  2  5
2  3  6
assert sp.registered_methods == {"write"}, "check that the method is registered"

patching headings

  • st.title
  • st.header
  • st.subheader
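
All three are wrapped with the same heading helper, parameterised by the markdown tag, as shown in the cell below. A minimal sketch of what _st_heading could look like (an assumption; the real implementation may differ):

def _st_heading_sketch(st_method: tp.Callable, tag: str = "#") -> tp.Callable:
    """sketch: render st.title/header/subheader as markdown headings in jupyter"""

    @functools.wraps(st_method)
    def wrapped(body, *args, **kwargs):
        if IN_IPYTHON:
            if not isinstance(body, str):
                raise TypeError(f"Unsupported type: {type(body)}")
            display(Markdown(f"{tag} {body}"))  # e.g. "# foo" for st.title("foo")
        else:
            return st_method(body, *args, **kwargs)

    return wrapped
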
sp = StreamlitPatcher()
sp._wrap("title", functools.partial(_st_heading, tag="#"))
sp._wrap("header", functools.partial(_st_heading, tag="##"))
sp._wrap("subheader", functools.partial(_st_heading, tag="###"))
with capture_output() as cap:
    st.title("foo")
    got = cap._outputs[0]["data"]["text/markdown"]

test_eq(got, "# foo")
with capture_output() as cap:
    st.header("foo")
    got = cap._outputs[0]["data"]["text/markdown"]

test_eq(got, "## foo")
with capture_output() as cap:
    st.subheader("foo")
    got = cap._outputs[0]["data"]["text/markdown"]

test_eq(got, "### foo")
# these should fail

test_fail(lambda: st.title(df), contains="Unsupported type")
test_fail(lambda: st.header(df), contains="Unsupported type")
test_fail(lambda: st.subheader(df), contains="Unsupported type")
test_fail(lambda: st.subheader(1), contains="Unsupported type")

st.caption

st.caption("This is a string that explains something above.")
st.caption("A caption with _italics_ :blue[colors] and emojis :sunglasses:")
st.caption("A caption with \n newlines")

This is a string that explains something above.

A caption with italics :blue[colors] and emojis :sunglasses:

A caption with newlines

Patch some methods to simply display the input in jupyter, after checking that it is of an allowed type.
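
A minimal sketch of such a type-checked display wrapper (an assumption; the real _st_type_check may differ):

def _st_type_check_sketch(st_method: tp.Callable, allowed_types=(str,)) -> tp.Callable:
    """sketch: display the input in jupyter after checking its type"""

    @functools.wraps(st_method)
    def wrapped(body, *args, **kwargs):
        if IN_IPYTHON:
            if not isinstance(body, allowed_types):
                raise TypeError(f"Unsupported type: {type(body)}")
            # strings render as markdown, dataframes via their rich repr
            display(Markdown(body) if isinstance(body, str) else body)
        else:
            return st_method(body, *args, **kwargs)

    return wrapped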

sp._wrap("markdown", functools.partial(_st_type_check, allowed_types=str))

test_fail(lambda: st.markdown(df), contains="Unsupported type")
st.markdown("This is **bold** text in markdown")

This is bold text in markdown

sp._wrap("dataframe", functools.partial(_st_type_check, allowed_types=pd.DataFrame))
test_fail(lambda: st.dataframe("foo"), contains="Unsupported type")
st.dataframe(df)
   a  b
0  1  4
1  2  5
2  3  6

st.code

st.code(
    """
def foo():
    print('hello')
"""
)

def foo():
    print('hello')
st.code("grep -r 'foo' .", language=None)
grep -r 'foo' .

st.text

st.latex

sp._wrap("latex", _st_latex)  # |hide_line
st.latex(r"E=mc^2")

\[\begin{equation}E=mc^2\end{equation}\]

st.latex(
    r"""a + ar + a r^2 + a r^3 + \cdots + a r^{n-1} =
        \sum_{k=0}^{n-1} ar^k =
        a \left(\frac{1-r^{n}}{1-r}\right)
"""
)

\[\begin{equation}a + ar + a r^2 + a r^3 + \cdots + a r^{n-1} = \sum_{k=0}^{n-1} ar^k = a \left(\frac{1-r^{n}}{1-r}\right) \end{equation}\]

st.json

Testing output of st.json with dict

body = {"foo": "bar", "baz": [1, 2, 3]}
expected = '```json\n{\n  "foo": "bar",\n  "baz": [\n    1,\n    2,\n    3\n  ]\n}\n```'  # |hide_line
test_md_output(st.json, expected, body)  # |hide_line
st.json(body)
{
  "foo": "bar",
  "baz": [
    1,
    2,
    3
  ]
}
body = {"foo": "bar", "baz": [1, 2, 3]}
expected = '```json\n{"foo": "bar", "baz": [1, 2, 3]}\n```'  # |hide_line
test_md_output(st.json, expected, body, expanded=False)  # |hide_line
st.json(body, expanded=False)
{"foo": "bar", "baz": [1, 2, 3]}

Testing output of st.json with str

body = '{"foo": "bar", "baz": [1,2,3]}'
expected = '```json\n{\n  "foo": "bar",\n  "baz": [\n    1,\n    2,\n    3\n  ]\n}\n```'  # |hide_line
test_md_output(st.json, expected, body)  # |hide_line
st.json(body)
{
  "foo": "bar",
  "baz": [
    1,
    2,
    3
  ]
}
body = '{"foo": "bar", "baz": [1,2,3]}'
expected = '```json\n{"foo": "bar", "baz": [1,2,3]}\n```'  # |hide_line
test_md_output(st.json, expected, body, expanded=False)  # |hide_line
st.json(body, expanded=False)
{"foo": "bar", "baz": [1,2,3]}

st.cache, st.cache_data, st.cache_resource

The streamlit cache decorators (st.cache, st.cache_data, st.cache_resource) cache the output of a function. This is useful for functions that take a long time to run and that we want to avoid re-running every time the app reruns.

In a jupyter notebook we can’t use streamlit’s caching, so we replace these methods with a dummy wrapper that does nothing.
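
A minimal sketch of such a dummy wrapper, assuming the cache decorators are used both bare (@st.cache) and called with options (@st.cache_data(), @st.cache_resource(ttl=...)); the real _dummy_wrapper_noop may differ:

def _dummy_wrapper_noop_sketch(st_method: tp.Callable) -> tp.Callable:
    """sketch: replace a streamlit cache decorator with a no-op"""

    @functools.wraps(st_method)  # keeps __name__ and __doc__, as asserted below
    def patched(*args, **kwargs):
        if args and callable(args[0]) and not kwargs:
            return args[0]  # used bare: @st.cache
        return lambda func: func  # used with options: @st.cache_data(), @st.cache_resource(ttl=...)

    return patched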

sp._wrap("cache", _dummy_wrapper_noop)
sp._wrap("cache_data", _dummy_wrapper_noop)
sp._wrap("cache_resource", _dummy_wrapper_noop)
# verify that during patching we didn't change the name or docstring
assert st.cache.__name__ == "cache"
assert "@st.cache" in tp.cast(
    str, st.cache.__doc__
), "check that the docstring is correct"
# test caching
@st.cache_data()
def get_data():
    st.write("Getting data...")
    for i in tqdm(range(5)):
        time.sleep(0.1)
    return pd.DataFrame({"c": [7, 8, 9], "d": [10, 11, 12]})


df = get_data()
st.write(df)

Getting data…

   c   d
0  7  10
1  8  11
2  9  12
# test that the cache in jupyter does not affect get_data

df = get_data()
with capture_output() as cap:
    st.write(df)
    got = cap._outputs[0]["data"]

expected = {
    "text/plain": "   c   d\n0  7  10\n1  8  11\n2  9  12",
    "text/html": '<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border="1" class="dataframe">\n  <thead>\n    <tr style="text-align: right;">\n      <th></th>\n      <th>c</th>\n      <th>d</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>0</th>\n      <td>7</td>\n      <td>10</td>\n    </tr>\n    <tr>\n      <th>1</th>\n      <td>8</td>\n      <td>11</td>\n    </tr>\n    <tr>\n      <th>2</th>\n      <td>9</td>\n      <td>12</td>\n    </tr>\n  </tbody>\n</table>\n</div>',
}

assert got == expected, "check that the output is correct"

Getting data…

# test caching
@st.cache_resource(ttl=3600)
def get_resource():
    st.write("Getting resource...")
    for i in tqdm(range(5)):
        time.sleep(0.1)
    return {
        "foo": "bar",
        "baz": [1, 2, 3],
        "qux": {"a": 1, "b": 2, "c": 3},
    }


expected = {
    "foo": "bar",
    "baz": [1, 2, 3],
    "qux": {"a": 1, "b": 2, "c": 3},
}

got = get_resource()
assert got == expected, "check that the output is correct"

Getting resource…

# test that the cache in jupyter does not affect get_data

records = get_resource()
with capture_output() as cap:
    st.write(records)
    got = cap._outputs[0]["data"]

expected = {
    "text/plain": "{'foo': 'bar', 'baz': [1, 2, 3], 'qux': {'a': 1, 'b': 2, 'c': 3}}"
}

assert got == expected, "check that the output is correct"

Getting resource…

st.expander

Note that this is an exception to the usual wrapper logic.

Since st.expander is used as a context manager, we replace it with a dummy class that displays the input in jupyter.
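
A minimal sketch of what such a context-manager replacement could look like, based on the output shown below (the real _st_expander may differ):

def _st_expander_sketch(st_method: tp.Callable) -> tp.Callable:
    """sketch: a context manager that only marks where the expander starts and ends"""

    class _Expander:
        def __init__(self, label: str, expanded: bool = False):
            self.label = label

        def __enter__(self):
            display(Markdown(f"expander starts: {self.label}"))
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            display(Markdown("expander ends"))
            return False  # don't swallow exceptions raised inside the block

    return _Expander if IN_IPYTHON else st_method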

sp._wrap("expander", _st_expander)
with st.expander("Expand me!", expanded=False):
    st.markdown(
        """
The **#30DaysOfStreamlit** is a coding challenge designed to help you get started in building Streamlit apps.

Particularly, you'll be able to:
- Set up a coding environment for building Streamlit apps
- Build your first Streamlit app
- Learn about all the awesome input/output widgets to use for your Streamlit app
    """
    )

    st.write("**More text, we can expand as many streamlit elements as we want**")

expander starts: Expand me!

The #30DaysOfStreamlit is a coding challenge designed to help you get started in building Streamlit apps.

Particularly, you’ll be able to:

  • Set up a coding environment for building Streamlit apps
  • Build your first Streamlit app
  • Learn about all the awesome input/output widgets to use for your Streamlit app

More text, we can expand as many streamlit elements as we want

expander ends

st.text_input

sp._wrap("text_input", _st_text_input)
sp._wrap("text_area", _st_text_input)
text = st.text_input("String:", "default text")
text
'default text'
text = st.text_area("Input:", "foo bar")
text
'foo bar'

st.date_input

sp._wrap("date_input", _st_date_input)

⚠️ Note the following limitation: when using this in jupyter, changing the date on your widget will not affect the date variable.

Streamlit behavior remains unchanged, though.
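
A sketch of why this limitation exists, assuming the wrapper shows an ipywidgets DatePicker but returns the parsed default value right away (the real _st_date_input may differ):

from datetime import datetime

import ipywidgets as widgets


def _st_date_input_sketch(st_method: tp.Callable) -> tp.Callable:
    """sketch: show a DatePicker widget but return the parsed default value"""

    @functools.wraps(st_method)
    def wrapped(label, value=None, **kwargs):
        if not IN_IPYTHON:
            return st_method(label, value=value, **kwargs)
        default = (
            datetime.strptime(value, "%Y-%m-%d").date()
            if isinstance(value, str)
            else value
        )
        display(widgets.DatePicker(description=label, value=default))
        # the value is returned immediately, so later widget changes cannot
        # update the variable that received this return value
        return default

    return wrapped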

date = st.date_input("Pick a date", value="2022-12-13")
assert date == datetime(2022, 12, 13).date()

st.checkbox

sp._wrap("checkbox", _st_checkbox)
show_code = st.checkbox("Show code")
assert show_code
show_code = st.checkbox("Show code", value=False)
assert not show_code

_st_radio and _st_selectbox

sp._wrap(
    "radio", functools.partial(_st_single_choice, jupyter_widget=widgets.RadioButtons)
)
sp._wrap(
    "selectbox", functools.partial(_st_single_choice, jupyter_widget=widgets.Dropdown)
)
st.radio("Pick", options=["foo", "bar"], index=1, key="radio")
'bar'
st.selectbox("Choose", options=["foo", "bar"])
'foo'

st.multiselect

sp._wrap("multiselect", _st_multiselect)
st.multiselect("Multiselect: ", options=["python", "golang", "julia", "rust"])
()
st.multiselect(
    "Multiselect with defaults: ",
    options=["nbdev", "streamlit", "jupyter", "fastcore"],
    default=["jupyter", "streamlit"],
)
('jupyter', 'streamlit')

st.metric

sp._wrap("metric", _st_metric)
# test that we don't allow invalid values for delta_color and label_visibility
test_fail(
    lambda: st.metric(
        "Speed", 300, 210, delta_color="FOOBAR", label_visibility="hidden"
    ),
    contains="delta_color",
)

test_fail(
    lambda: st.metric(
        "Speed", 300, 210, delta_color="normal", label_visibility="FOOBAR"
    ),
    contains="label_visibility",
)

# display a metric
st.metric("Speed", 300, 210, delta_color="normal", label_visibility="hidden")
2023-03-06 17:34:09.265 WARNING __main__: `delta_color` argument is not supported in Jupyter notebooks, but will be applied in Streamlit
2023-03-06 17:34:09.266 WARNING __main__: `label_visibility` argument is not supported in Jupyter notebooks, but will be applied in Streamlit
2023-03-06 17:34:09.267 WARNING __main__: plotly is not installed, falling back to default st.metric implementation
To use plotly, run `pip install plotly`

st.metric widget (this will work as expected in streamlit)

st.columns

ToDo:

- [ ] add support for st.columns in jupyter

# logger.warning("Not implemented yet")

StreamlitPatcher.MAPPING

MAPPING is a dictionary that maps each streamlit method name to the wrapper we want to use instead.

This is used when StreamlitPatcher.jupyter() is called.
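
For illustration, a shortened sketch of what MAPPING could contain, built from the wrappers used above; the actual mapping covers every method listed in sp.registered_methods at the bottom of this page:

MAPPING_SKETCH: tp.Dict[str, tp.Callable] = {
    "write": _st_write,
    "title": functools.partial(_st_heading, tag="#"),
    "header": functools.partial(_st_heading, tag="##"),
    "subheader": functools.partial(_st_heading, tag="###"),
    "markdown": functools.partial(_st_type_check, allowed_types=str),
    "latex": _st_latex,
    "expander": _st_expander,
    "cache": _dummy_wrapper_noop,
    "cache_data": _dummy_wrapper_noop,
    "cache_resource": _dummy_wrapper_noop,
    # ... one entry for every method in sp.registered_methods shown below
}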

sp = StreamlitPatcher()
assert not sp.registered_methods, "registered methods should be empty at this point"

source

StreamlitPatcher.jupyter

 StreamlitPatcher.jupyter ()

patches streamlit methods to display content in jupyter notebooks

sp.jupyter()
sp.registered_methods
{'cache',
 'cache_data',
 'cache_resource',
 'caption',
 'checkbox',
 'code',
 'dataframe',
 'date_input',
 'expander',
 'header',
 'json',
 'latex',
 'markdown',
 'metric',
 'multiselect',
 'radio',
 'selectbox',
 'subheader',
 'text',
 'text_area',
 'text_input',
 'title',
 'write'}