Skip to content

Commit

Permalink
Basic glyph extension algorithm: Stage wrapper.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 716222677
  • Loading branch information
agutkin committed Jan 16, 2025
1 parent 91660a1 commit 717cecd
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 8 deletions.
20 changes: 12 additions & 8 deletions protoscribe/evolution/confidence_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,28 +32,32 @@
DEFAULT_MAX_CUMULATIVE_PROBABILITY = 1.


class Method(enum.IntEnum):
"""Method for pruning the list of best hypotheses for the test set."""
class Method(enum.StrEnum):
"""Method for pruning the list of best hypotheses for the test set.
Needs to be derived from `StrEnum` because we use this in flags which are
passed around as strings both parsed and unparsed.
"""
# No pruning. Pass through all the results.
NONE = 0
NONE = "none"

# Prune all the candidates below a certain confidence threshold, which is
# interpreted as an absolute value.
THRESHOLD = 1
THRESHOLD = "threshold"

# Convert the absolute confidence values to distribution over the test set
# and prune by probability threshold.
PROBABILITY = 2
PROBABILITY = "probability"

# Keep top-K best results according to absolute value of the confidence.
TOP_K = 3
TOP_K = "top_k"

# Keep given percentage of the best results.
TOP_PERCENTAGE = 4
TOP_PERCENTAGE = "top_percentage"

# Finds the top results with cumulative probability mass smaller than a
# specified probability threshold (technique from nucleus sampling).
TOP_P = 5
TOP_P = "top_p"


def confidence(results: dict[str, Any]) -> float:
Expand Down
4 changes: 4 additions & 0 deletions protoscribe/evolution/new_spellings_basic_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def main(unused_argv):
# update the lists of seen and unseen concepts.
spellings = []
if _PREVIOUS_SPELLINGS.value:
logging.info(
"Reading previous round spellings from %s ...",
_PREVIOUS_SPELLINGS.value
)
with open(_PREVIOUS_SPELLINGS.value) as s:
spellings = [l.strip() for l in s.readlines()]
for concept, glyph, _, _, _ in glyphs:
Expand Down
168 changes: 168 additions & 0 deletions protoscribe/evolution/stages/new_spellings_basic_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Stage wrapper over basic spelling extension algorithm."""

from collections.abc import Sequence
import logging
import os
import tempfile

from absl import app
from absl import flags
from protoscribe.evolution import new_spellings_utils # pylint: disable=unused-import Import flags.
from protoscribe.evolution.stages import common_flags
from protoscribe.evolution.stages import utils
from protoscribe.utils import file_utils
from protoscribe.utils import subprocess_utils

import glob
import os

_MODE = flags.DEFINE_enum(
"mode", "sketch-token",
[
"sketch-token",
"sketch-token-and-glyph",
],
"Type of sketch mdoel. Can be 'sketch-token' for pure sketch generation or "
"'sketch-token-and-glyph' for combined glyph and sletch prediction. "
"This is a prefix part of the model configuration in 'configs' directory."
)

_EXPERIMENT_NAME = flags.DEFINE_string(
"experiment_name", None,
"An experiment name which will define the directory in which the "
"evolving system data is placed.",
required=True
)

FLAGS = flags.FLAGS

# Actual spelling extension tool.
_NEW_SPELLINGS_TOOL = (
"protoscribe/evolution/new_spellings_basic"
)


def _results_jsonl_for_model_type(
round_data_dir: str, model_type: str, experiment_id: str
) -> str:
"""Returns results JSONL for the given model type.
Args:
round_data_dir: Data directory for this round.
model_type: Type of the model.
experiment_id: XManager job ID.
"""

# Figure out directory for the outputs.
round_id = common_flags.ROUND.value
experiment_name = (
f"{_EXPERIMENT_NAME.value}:{round_id}:{_MODE.value}_{model_type}"
)
if _MODE.value == "sketch-token":
experiment_name = f"{experiment_name}:reco"
output_dir = utils.setup_inference_directories(
round_data_dir=round_data_dir,
experiment_name=experiment_name,
experiment_id=experiment_id
)
jsonl_path = os.path.join(output_dir, "results.jsonl")
logging.info("JSONL results in %s ...", jsonl_path)
return jsonl_path


def main(argv: Sequence[str]) -> None:
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")

round_data_dir = common_flags.round_data_dir()
logging.info("Using data location: %s", round_data_dir)

# Find previous spellings extensions, if any.
prev_spellings_file = None
if common_flags.ROUND.value > 0:
prev_spellings_file = os.path.join(
common_flags.previous_data_dir(), "inference_extensions/spellings.tsv"
)
if not os.path.exists(prev_spellings_file):
raise ValueError(
f"Previous spelling extensions {prev_spellings_file} not found!"
)

# Get the paths to JSONL files containing the results from the current
# round.
semantics_results_jsonl = _results_jsonl_for_model_type(
round_data_dir=round_data_dir,
model_type=common_flags.SEMANTIC_MODEL.value,
experiment_id=common_flags.SEMANTICS_XID.value
)
phonetics_results_jsonl = _results_jsonl_for_model_type(
round_data_dir=round_data_dir,
model_type=common_flags.PHONETIC_MODEL.value,
experiment_id=common_flags.PHONETICS_XID.value
)

# Setup command-line flags to call the actual spellings extension tool.
admin_categories = f"{round_data_dir}/administrative_categories.txt"
non_admin_categories = f"{round_data_dir}/non_administrative_categories.txt"
args = [
"--data_location", round_data_dir,
"--semantic_jsonl_file", semantics_results_jsonl,
"--phonetic_jsonl_file", phonetics_results_jsonl,
"--administrative_categories", admin_categories,
"--non_administrative_categories", non_admin_categories,
# TODO: The plumbing for flags below is not great. Maybe refactor
# using protocol buffers.
"--pruning_method", FLAGS.pruning_method,
"--minimum_semantic_confidence", FLAGS.minimum_semantic_confidence,
"--minimum_phonetic_confidence", FLAGS.minimum_phonetic_confidence,
"--minimum_semantic_prob", FLAGS.minimum_semantic_prob,
"--minimum_phonetic_prob", FLAGS.minimum_phonetic_prob,
"--semantic_top_k", FLAGS.semantic_top_k,
"--phonetic_top_k", FLAGS.phonetic_top_k,
"--semantic_top_percentage", FLAGS.semantic_top_percentage,
"--phonetic_top_percentage", FLAGS.phonetic_top_percentage,
"--semantic_top_p", FLAGS.semantic_top_p,
"--phonetic_top_p", FLAGS.phonetic_top_p,
]
if common_flags.ROUND.value > 0:
args.extend(["--previous_spellings", prev_spellings_file])

# For sketches, we also need to set up the directory for outputing the
# actual glyphs as SVGs.
svg_temp_dir = tempfile.TemporaryDirectory()
if _MODE.value == "sketch-token":
output_glyph_graphics_dir = os.path.join(
round_data_dir, "glyph_extensions_svg"
)
if not os.path.exists(output_glyph_graphics_dir):
os.makedirs(output_glyph_graphics_dir, exist_ok=True)
args.extend([
"--output_glyph_graphics_dir", svg_temp_dir.name
])

# Run the algorithm.
subprocess_utils.run_subprocess(_NEW_SPELLINGS_TOOL, args=args)

# Copy the extensions from temp directory.
if _MODE.value == "sketch-token":
logging.info("Copying glyph SVGs to %s ...", output_glyph_graphics_dir)
file_utils.copy_dir(svg_temp_dir.name, output_glyph_graphics_dir)
svg_temp_dir.cleanup()


if __name__ == "__main__":
app.run(main)

0 comments on commit 717cecd

Please sign in to comment.