core

modifying streamlit methods to play nice with jupyter

tqdm patch

Calling tqdm as tqdm.notebook or stqdm depending on environment


source

StreamlitPatcher

 StreamlitPatcher ()

class to patch streamlit functions for displaying content in jupyter notebooks

Exported source
class StreamlitPatcher:
    """Patches streamlit functions so their content can be displayed in jupyter notebooks."""

    def __init__(self):
        # names of the streamlit methods this instance has replaced so far
        self.registered_methods: tp.Set[str] = set()
        # flipped to True once `jupyter()` has run
        self.is_registered: bool = False

    def jupyter(self):
        """Apply every wrapper listed in ``MAPPING`` to its streamlit method."""
        # MAPPING is looked up once; it pairs a method name with its wrapper
        mapping = self.MAPPING
        for name in mapping:
            self._wrap(name, mapping[name])

        self.is_registered = True

    @staticmethod
    def _get_streamlit_methods():
        """Return the names of all public (non-underscore) attributes of ``st``."""
        public_names = []
        for name in dir(st):
            if name.startswith("_"):
                continue  # private / dunder attribute
            public_names.append(name)
        return public_names
Exported source
@patch_to(StreamlitPatcher, cls_method=False)
def _wrap(
    cls,
    method_name: str,
    wrapper: tp.Callable,
) -> None:
    """Swap a streamlit method for its wrapped, jupyter-friendly version.

    Does nothing when not running inside IPython.

    Parameters
    ----------
    method_name : str
        name of the attribute on ``st`` to replace
    wrapper : tp.Callable
        called with the current ``st`` attribute; whatever it returns is
        installed on ``st`` under the same name
    """
    if not IN_IPYTHON:
        # outside of jupyter, streamlit is left untouched
        return

    original = getattr(st, method_name)
    patched = wrapper(original)
    setattr(st, method_name, patched)
    # keep track of which methods have been replaced
    cls.registered_methods.add(method_name)
sp = StreamlitPatcher()

assert not sp.is_registered, "StreamlitPatcher is already registered"

Modifying streamlit

The way we will modify streamlit methods is by putting them through a decorator. This decorator will check if we are in a jupyter notebook, and if so, it will take the input and display it in the notebook.

Else it will use the original streamlit method.

st.write

sp._wrap("write", _st_write)

with capture_output() as cap:
    st.write("hello")
    got = cap._outputs[0]["data"]

expected = {
    "text/plain": "<IPython.core.display.Markdown object>",
    "text/markdown": "hello",
}
assert got == expected, "check that the output is correct"

st.write("hello")

hello

st.write("This is **bold** text in markdown")

This is bold text in markdown

try:
    df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    st.write(df)
except ImportError:
    logger.warning("Pandas not installed, skipping test")
a b
0 1 4
1 2 5
2 3 6
assert sp.registered_methods == {"write"}, "check that the method is registered"

patching headings

  • st.title
  • st.header
  • st.subheader
sp = StreamlitPatcher()
sp._wrap("title", functools.partial(_st_heading, tag="#"))
sp._wrap("header", functools.partial(_st_heading, tag="##"))
sp._wrap("subheader", functools.partial(_st_heading, tag="###"))
with capture_output() as cap:
    st.title("foo")
    got = cap._outputs[0]["data"]["text/markdown"]

test_eq(got, "# foo")
with capture_output() as cap:
    st.header("foo")
    got = cap._outputs[0]["data"]["text/markdown"]

test_eq(got, "## foo")
with capture_output() as cap:
    st.subheader("foo")
    got = cap._outputs[0]["data"]["text/markdown"]

test_eq(got, "### foo")
# these should fail

test_fail(lambda: st.title(df), contains="Unsupported type")
test_fail(lambda: st.header(df), contains="Unsupported type")
test_fail(lambda: st.subheader(df), contains="Unsupported type")
test_fail(lambda: st.subheader(1), contains="Unsupported type")

st.caption

st.caption("This is a string that explains something above.")
st.caption("A caption with _italics_ :blue[colors] and emojis :sunglasses:")
st.caption("A caption with \n newlines")

This is a string that explains something above.

A caption with italics :blue[colors] and emojis :sunglasses:

A caption with newlines

patch some methods to simply display the input in jupyter

sp._wrap("markdown", functools.partial(_st_type_check, allowed_types=str))

test_fail(lambda: st.markdown(df), contains="Unsupported type")
st.markdown("This is **bold** text in markdown")

This is bold text in markdown

sp._wrap("dataframe", functools.partial(_st_type_check, allowed_types=pd.DataFrame))
test_fail(lambda: st.dataframe("foo"), contains="Unsupported type")
st.dataframe(df)
a b
0 1 4
1 2 5
2 3 6

st.code

st.code(
    """
def foo():
    print('hello')
"""
)

def foo():
    print('hello')
st.code("grep -r 'foo' .", language=None)
grep -r 'foo' .

st.text

st.latex

sp._wrap("latex", _st_latex)  # |hide_line
st.latex(r"E=mc^2")

\[\begin{equation}E=mc^2\end{equation}\]

st.latex(
    r"""a + ar + a r^2 + a r^3 + \cdots + a r^{n-1} =
        \sum_{k=0}^{n-1} ar^k =
        a \left(\frac{1-r^{n}}{1-r}\right)
"""
)

\[\begin{equation}a + ar + a r^2 + a r^3 + \cdots + a r^{n-1} = \sum_{k=0}^{n-1} ar^k = a \left(\frac{1-r^{n}}{1-r}\right) \end{equation}\]

st.json

Testing output of st.json with dict

body = {"foo": "bar", "baz": [1, 2, 3]}
expected = '```json\n{\n  "foo": "bar",\n  "baz": [\n    1,\n    2,\n    3\n  ]\n}\n```'  # |hide_line
test_md_output(st.json, expected, body)  # |hide_line
st.json(body)
{
  "foo": "bar",
  "baz": [
    1,
    2,
    3
  ]
}
body = {"foo": "bar", "baz": [1, 2, 3]}
expected = '```json\n{"foo": "bar", "baz": [1, 2, 3]}\n```'  # |hide_line
test_md_output(st.json, expected, body, expanded=False)  # |hide_line
st.json(body, expanded=False)
{"foo": "bar", "baz": [1, 2, 3]}

Testing output of st.json with str

body = '{"foo": "bar", "baz": [1,2,3]}'
expected = '```json\n{\n  "foo": "bar",\n  "baz": [\n    1,\n    2,\n    3\n  ]\n}\n```'  # |hide_line
test_md_output(st.json, expected, body)  # |hide_line
st.json(body)
{
  "foo": "bar",
  "baz": [
    1,
    2,
    3
  ]
}
body = '{"foo": "bar", "baz": [1,2,3]}'
expected = '```json\n{"foo": "bar", "baz": [1,2,3]}\n```'  # |hide_line
test_md_output(st.json, expected, body, expanded=False)  # |hide_line
st.json(body, expanded=False)
{"foo": "bar", "baz": [1,2,3]}

st.cache, st.cache_data, st.cache_resource

The streamlit cache methods (`st.cache`, `st.cache_data` and `st.cache_resource`) are used to cache the output of a function. This is useful for functions that take a long time to run and that we want to avoid re-running every time the app runs.

If we are in a jupyter notebook, we can’t use the streamlit cache methods, so we replace each of them with a dummy wrapper that does nothing.

sp._wrap("cache", _dummy_wrapper_noop)
sp._wrap("cache_data", _dummy_wrapper_noop)
sp._wrap("cache_resource", _dummy_wrapper_noop)
# verify that during patching we didn't change the name or docstring
assert st.cache.__name__ == "cache"
assert "@st.cache" in tp.cast(
    str, st.cache.__doc__
), "check that the docstring is correct"
# test caching
@st.cache_data()
def get_data():
    """Announce the load, pause briefly five times, and return a small two-column frame."""
    st.write("Getting data...")
    # the loop only exists to take some time; tqdm shows a progress bar for it
    for _ in tqdm(range(5)):
        time.sleep(0.1)
    columns = {"c": [7, 8, 9], "d": [10, 11, 12]}
    return pd.DataFrame(columns)


df = get_data()
st.write(df)

Getting data…

c d
0 7 10
1 8 11
2 9 12
# test that the cache in jupyter does not affect get_data

df = get_data()
with capture_output() as cap:
    st.write(df)
    got = cap._outputs[0]["data"]

expected = {
    "text/plain": "   c   d\n0  7  10\n1  8  11\n2  9  12",
    "text/html": '<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border="1" class="dataframe">\n  <thead>\n    <tr style="text-align: right;">\n      <th></th>\n      <th>c</th>\n      <th>d</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>0</th>\n      <td>7</td>\n      <td>10</td>\n    </tr>\n    <tr>\n      <th>1</th>\n      <td>8</td>\n      <td>11</td>\n    </tr>\n    <tr>\n      <th>2</th>\n      <td>9</td>\n      <td>12</td>\n    </tr>\n  </tbody>\n</table>\n</div>',
}

assert got == expected, "check that the output is correct"

Getting data…

# test caching
@st.cache_resource(ttl=3600)
def get_resource():
    """Announce the load, pause briefly five times, and return a nested dict."""
    st.write("Getting resource...")
    # the loop only exists to take some time; tqdm shows a progress bar for it
    for _ in tqdm(range(5)):
        time.sleep(0.1)
    resource = {
        "foo": "bar",
        "baz": [1, 2, 3],
        "qux": {"a": 1, "b": 2, "c": 3},
    }
    return resource


expected = {
    "foo": "bar",
    "baz": [1, 2, 3],
    "qux": {"a": 1, "b": 2, "c": 3},
}

got = get_resource()
assert got == expected, "check that the output is correct"

Getting resource…

# test that the cache in jupyter does not affect get_resource

records = get_resource()
with capture_output() as cap:
    st.write(records)
    got = cap._outputs[0]["data"]

expected = {
    "text/plain": "{'foo': 'bar', 'baz': [1, 2, 3], 'qux': {'a': 1, 'b': 2, 'c': 3}}"
}

assert got == expected, "check that the output is correct"

Getting resource…

st.expander

Note that this is an exception to the usual wrapper logic.

Since st.expander is used as a context manager, we replace it with a dummy class that displays the input in jupyter.

sp._wrap("expander", _st_expander)
with st.expander("Expand me!", expanded=False):
    st.markdown(
        """
The **#30DaysOfStreamlit** is a coding challenge designed to help you get started in building Streamlit apps.

Particularly, you'll be able to:
- Set up a coding environment for building Streamlit apps
- Build your first Streamlit app
- Learn about all the awesome input/output widgets to use for your Streamlit app
    """
    )

    st.write("**More text, we can expand as many streamlit elements as we want**")

expander starts: Expand me!

The #30DaysOfStreamlit is a coding challenge designed to help you get started in building Streamlit apps.

Particularly, you’ll be able to: - Set up a coding environment for building Streamlit apps - Build your first Streamlit app - Learn about all the awesome input/output widgets to use for your Streamlit app

More text, we can expand as many streamlit elements as we want

expander ends

st.text_input

sp._wrap("text_input", _st_text_input)
sp._wrap("text_area", _st_text_input)
text = st.text_input("String:", "default text")
text
'default text'
text = st.text_area("Input:", "foo bar")
text
'foo bar'

st.date_input

sp._wrap("date_input", _st_date_input)

⚠️ Note the following limitation: when using this in jupyter, changing the date on your widget will not affect the date variable.

Streamlit behavior will remain unchanged though

date = st.date_input("Pick a date", value="2022-12-13")
assert date == datetime(2022, 12, 13).date()

st.checkbox

sp._wrap("checkbox", _st_checkbox)
show_code = st.checkbox("Show code")
assert show_code
show_code = st.checkbox("Show code", value=False)
assert not show_code

st.radio and st.selectbox

sp._wrap(
    "radio", functools.partial(_st_single_choice, jupyter_widget=widgets.RadioButtons)
)
sp._wrap(
    "selectbox", functools.partial(_st_single_choice, jupyter_widget=widgets.Dropdown)
)
st.radio("Pick", options=["foo", "bar"], index=1, key="radio")
'bar'
st.selectbox("Choose", options=["foo", "bar"])
'foo'

st.multiselect

sp._wrap("multiselect", _st_multiselect)
st.multiselect("Multiselect: ", options=["python", "golang", "julia", "rust"])
()
st.multiselect(
    "Multiselect with defaults: ",
    options=["nbdev", "streamlit", "jupyter", "fastcore"],
    default=["jupyter", "streamlit"],
)
('jupyter', 'streamlit')

st.metric

sp._wrap("metric", _st_metric)
# test that we don't allow invalid values for delta_color and label_visibility
test_fail(
    lambda: st.metric(
        "Speed", 300, 210, delta_color="FOOBAR", label_visibility="hidden"
    ),
    contains="delta_color",
)

test_fail(
    lambda: st.metric(
        "Speed", 300, 210, delta_color="normal", label_visibility="FOOBAR"
    ),
    contains="label_visibility",
)

# display a metric
st.metric("Speed", 300, 210, delta_color="normal", label_visibility="hidden")
2023-03-06 17:34:09.265 WARNING __main__: `delta_color` argument is not supported in Jupyter notebooks, but will be applied in Streamlit
2023-03-06 17:34:09.266 WARNING __main__: `label_visibility` argument is not supported in Jupyter notebooks, but will be applied in Streamlit
2023-03-06 17:34:09.267 WARNING __main__: plotly is not installed, falling back to default st.metric implementation
To use plotly, run `pip install plotly`

st.metric widget (this will work as expected in streamlit)

st.columns

ToDo: - [ ] add support for st.columns in jupyter

# logger.warning("Not implemented yet")

StreamlitPatcher.MAPPING

`MAPPING` is a dictionary that maps the name of each streamlit method to the wrapper we want to apply to it.

This is used when StreamlitPatcher.jupyter() is called.

sp = StreamlitPatcher()
assert not sp.registered_methods, "registered methods should be empty at this point"

source

StreamlitPatcher.jupyter

 StreamlitPatcher.jupyter ()

patches streamlit methods to display content in jupyter notebooks

sp.jupyter()
sp.registered_methods
{'cache',
 'cache_data',
 'cache_resource',
 'caption',
 'checkbox',
 'code',
 'dataframe',
 'date_input',
 'expander',
 'header',
 'json',
 'latex',
 'markdown',
 'metric',
 'multiselect',
 'radio',
 'selectbox',
 'subheader',
 'text',
 'text_area',
 'text_input',
 'title',
 'write'}