Skip to content

Commit dc66cfc

Browse files
Update /generate to not split classes & functions across cells (#1158)
* Update to ensure no hanging code cells in generated notebooks * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update generate.py * Update generate.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update generate.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2e80956 commit dc66cfc

File tree

1 file changed

+30
-0
lines changed
  • packages/jupyter-ai/jupyter_ai/chat_handlers

1 file changed

+30
-0
lines changed

packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import ast
12
import asyncio
23
import os
34
import time
@@ -198,6 +199,15 @@ async def afill_outline(outline, llm, verbose=False):
198199
await asyncio.gather(*all_coros)
199200

200201

202+
# Check if the content of the cell is python code or not
203+
def is_not_python_code(source: str) -> bool:
204+
try:
205+
ast.parse(source)
206+
except:
207+
return True
208+
return False
209+
210+
201211
def create_notebook(outline):
202212
"""Create an nbformat Notebook object for a notebook outline."""
203213
nbf = nbformat.v4
@@ -212,6 +222,26 @@ def create_notebook(outline):
212222
nb["cells"].append(nbf.new_markdown_cell("## " + section["title"]))
213223
for code_block in section["code"].split("\n\n"):
214224
nb["cells"].append(nbf.new_code_cell(code_block))
225+
226+
# Post process notebook for hanging code cells: merge hanging cell with the previous cell
227+
merged_cells = []
228+
for cell in nb["cells"]:
229+
# Fix a hanging code cell
230+
follows_code_cell = merged_cells and merged_cells[-1]["cell_type"] == "code"
231+
is_incomplete = cell["cell_type"] == "code" and cell["source"].startswith(" ")
232+
if follows_code_cell and is_incomplete:
233+
merged_cells[-1]["source"] = (
234+
merged_cells[-1]["source"] + "\n\n" + cell["source"]
235+
)
236+
else:
237+
merged_cells.append(cell)
238+
239+
# Fix code cells that should be markdown
240+
for cell in merged_cells:
241+
if cell["cell_type"] == "code" and is_not_python_code(cell["source"]):
242+
cell["cell_type"] = "markdown"
243+
244+
nb["cells"] = merged_cells
215245
return nb
216246

217247

0 commit comments

Comments
 (0)