Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Terminal colors #905

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ node_modules
.eggs/
.env
.DS_Store
.idea

# Ignore native library built by setup
guidance/*.so
Expand Down
33 changes: 33 additions & 0 deletions guidance/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sys
import textwrap
import types
import re

import numpy as np

Expand Down Expand Up @@ -261,3 +262,35 @@ def softmax(array: np.ndarray, axis: int = -1) -> np.ndarray:
array_maxs = np.amax(array, axis=axis, keepdims=True)
exp_x_shifted = np.exp(array - array_maxs)
return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)


# Is it good to allow user to create their own instances of output?
class ReadableOutput:
    """Base interface for renderers that turn raw model state into printable text."""

    def feed(self, state_list):
        """Parse model output and return something to print.

        Parameters
        ----------
        state_list : list
            Each entry is a two-item sequence: index 0 is the text chunk,
            index 1 is an RGBA color tuple when the text was generated by
            the model, or None when the text was inserted by us.

        Subclasses must override this; the base implementation only raises,
        so nothing would ever be produced otherwise.
        """
        raise NotImplementedError('"feed" must be implemented!')


class ReadableOutputCLIStream(ReadableOutput):
    """Streams model state to a terminal, coloring generated text with ANSI codes."""

    def __init__(self):
        # Index of the first state entry we have not yet rendered, so that
        # repeated feed() calls emit only the new chunks.
        self._cur_chunk = 0
        super().__init__()

    def feed(self, state_list):
        """Render every not-yet-seen chunk of state_list and return the text."""
        pieces = []
        for text, color in state_list[self._cur_chunk:]:
            if color is None:
                # Inserted (non-generated) text is printed unstyled.
                pieces.append(text)
            else:
                # 24-bit foreground color (SGR 38;2;r;g;b), reset afterwards.
                r, g, b = (round(channel) for channel in color[:3])
                pieces.append(f"\033[38;2;{r};{g};{b}m{text}\033[0m")
            self._cur_chunk += 1
        return "".join(pieces)
33 changes: 25 additions & 8 deletions guidance/models/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,24 @@
import time
import warnings


from pprint import pprint
from typing import Dict, TYPE_CHECKING


import numpy as np

try:
from IPython import get_ipython
from IPython.display import clear_output, display, HTML

ipython_is_imported = True
except ImportError:
ipython_is_imported = False
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need this, or is it replaced by notebook_mode?

notebook_mode = False
else:
ipython_is_imported = True
_ipython = get_ipython()
notebook_mode = (
_ipython is not None
and "IPKernelApp" in _ipython.config
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on the way that tqdm.auto determines notebook context -- would be good to test this on multiple machines/platforms...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We had to tackle this detection back in InterpretML too, here's how we did it there: https://github.com/interpretml/interpret/blob/develop/python/interpret-core/interpret/provider/_environment.py. We could re-use some of this logic here perhaps

cc @nopdive

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, hadn't been watching my notifications -- code seems reasonable. Yes, this should be tested (either manually or automatically is fine) on multiple platforms before merge: terminal (Windows/Linux/Mac), vscode, jupyter notebook/lab, azure/google/amazon/databricks notebooks. We should be relatively okay if these target environments work.

)

try:
import torch

Expand All @@ -39,7 +44,7 @@
)
from .. import _cpp as cpp
from ._guidance_engine_metrics import GuidanceEngineMetrics
from .._utils import softmax, CaptureEvents
from .._utils import softmax, CaptureEvents, ReadableOutputCLIStream
from .._parser import EarleyCommitParser, Parser
from .._grammar import (
GrammarFunction,
Expand Down Expand Up @@ -857,11 +862,13 @@ def __init__(self, engine, echo=True, **kwargs):
self._variables = {} # these are the state variables stored with the model
self._variables_log_probs = {} # these are the state variables stored with the model
self._cache_state = {} # mutable caching state used to save computation
self._state_list = []
self._state = "" # the current bytes that represent the state of the model
self._event_queue = None # TODO: these are for streaming results in code, but that needs implemented
self._event_parent = None
self._last_display = 0 # used to track the last display call to enable throttling
self._last_event_stream = 0 # used to track the last event streaming call to enable throttling
self._state_dict_parser = ReadableOutputCLIStream() # used to parse the state for cli display

@property
def active_role_end(self):
Expand Down Expand Up @@ -975,11 +982,11 @@ def _update_display(self, throttle=True):
else:
self._last_display = curr_time

if ipython_is_imported:
if notebook_mode:
clear_output(wait=True)
display(HTML(self._html()))
else:
pprint(self._state)
print(self._state_dict_parser.feed(self._state_list), end='', flush=True)

def reset(self, clear_variables=True):
"""This resets the state of the model object.
Expand All @@ -995,6 +1002,7 @@ def reset(self, clear_variables=True):
self._variables_log_probs = {}
return self

# Is this used anywhere?
def _repr_html_(self):
if ipython_is_imported:
clear_output(wait=True)
Expand Down Expand Up @@ -1327,9 +1335,18 @@ def _run_stateless(self, stateless_function, temperature=0.0, top_p=1.0, n=1):

if len(chunk.new_bytes) > 0:
generated_value += new_text

# Add text to state list
self._state_list.append([new_text, None])

if chunk.is_generated:
lm += f"<||_html:<span style='background-color: rgba({165*(1-chunk.new_bytes_prob) + 0}, {165*chunk.new_bytes_prob + 0}, 0, {0.15}); border-radius: 3px;' title='{chunk.new_bytes_prob}'>_||>"

# If that was generated text - color it
self._state_list[-1][1] = (165 * (1 - chunk.new_bytes_prob) + 0, 165 * chunk.new_bytes_prob + 0, 0, 0.15)

lm += new_text

if chunk.is_generated:
lm += "<||_html:</span>_||>"

Expand Down
6 changes: 6 additions & 0 deletions tests/unit/test_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
from guidance import gen, models, optional, select


def test_readable_output():
    """Smoke-test that a Mock model with a select() still stringifies cleanly."""
    lm = models.Mock()
    lm += "Not colored " + select(options=["colored", "coloblue", "cologreen"])
    expected = {"Not colored colored", "Not colored coloblue", "Not colored cologreen"}
    assert str(lm) in expected


def test_select_reset_pos():
model = models.Mock()
model += "This is" + select(options=["bad", "quite bad"])
Expand Down
Loading