Skip to content

Commit 92efffe

Browse files
committed
Add support for custom model names in SDG
Prior to this, the model name had to be called exactly "mixtral" but a user may choose to name it differently. This adds support for that. Signed-off-by: mprahl <[email protected]>
1 parent 77706de commit 92efffe

File tree

2 files changed

+55
-38
lines changed

2 files changed

+55
-38
lines changed

Diff for: pipeline.yaml

+38-31
Original file line numberDiff line numberDiff line change
@@ -1857,10 +1857,7 @@ deploymentSpec:
18571857
):\n # Handle where the KFP SDK is <2.12.2.\n escaped_uri\
18581858
\ = tokenizer_model_path[len(\"oci://\") :].replace(\"/\", \"_\")\n \
18591859
\ tokenizer_model_path = os.path.join(\"/oci\", escaped_uri, \"models\"\
1860-
)\n\n # A hack because InstructLab assumes the value for model_name is\
1861-
\ a valid path and the name of the model.\n os.symlink(tokenizer_model_path,\
1862-
\ os.path.join(tempfile.gettempdir(), \"mixtral\"))\n os.chdir(tempfile.gettempdir())\n\
1863-
\n if not taxonomy_repo_secret:\n username = os.getenv(\"GIT_USERNAME\"\
1860+
)\n\n if not taxonomy_repo_secret:\n username = os.getenv(\"GIT_USERNAME\"\
18641861
)\n token = os.getenv(\"GIT_TOKEN\")\n ssh_key = os.getenv(\"\
18651862
GIT_SSH_KEY\")\n else:\n print(\"SDG Repo secret specified, fetching...\"\
18661863
)\n username, token, ssh_key = fetch_secret(\n taxonomy_repo_secret,\
@@ -1931,29 +1928,38 @@ deploymentSpec:
19311928
\ cwd=taxonomy_path,\n env=env,\n )\n exec_cmd([\"\
19321929
git\", \"checkout\", repo_branch], cwd=taxonomy_path, env=env)\n\n if\
19331930
\ sdg_secret_name is None:\n api_key = os.getenv(\"api_key\")\n \
1934-
\ endpoint = os.getenv(\"endpoint\")\n else:\n print(\"\
1935-
SDG Teacher secret specified, fetching...\")\n api_key, endpoint\
1936-
\ = fetch_secret(sdg_secret_name, [\"api_token\", \"endpoint\"])\n \
1937-
\ print(\"SDG Teacher secret data retrieved.\")\n\n # Use the default\
1938-
\ SSL context since it leverages OpenSSL to use the correct CA bundle.\n\
1939-
\ http_client = httpx.Client(verify=ssl.create_default_context())\n \
1940-
\ client = openai.OpenAI(base_url=endpoint, api_key=api_key, http_client=http_client)\n\
1941-
\n taxonomy_base = \"main\" if repo_branch or (repo_pr and int(repo_pr)\
1942-
\ > 0) else \"empty\"\n\n print(\"Generating synthetic dataset for:\"\
1943-
)\n print()\n print(\n instructlab.sdg.utils.taxonomy.read_taxonomy(\n\
1944-
\ taxonomy_path, taxonomy_base, document_output_dir=f\"{sdg_path}/documents\"\
1945-
\n )\n )\n\n # Generate synthetic dataset\n # 1.0 is the\
1946-
\ default size\n if sdg_sampling_size == 1.0:\n # generate_data\
1947-
\ has a magic word for its taxonomy_base argument - 'empty'\n # it\
1948-
\ allows generating from the whole repo, see:\n # https://github.com/instructlab/sdg/blob/c6a9e74a1618b1077cd38e713b8aaed8b7c0c8ce/src/instructlab/sdg/utils/taxonomy.py#L230\n\
1931+
\ model_name = os.getenv(\"model_name\")\n endpoint = os.getenv(\"\
1932+
endpoint\")\n else:\n print(\"SDG Teacher secret specified, fetching...\"\
1933+
)\n api_key, model_name, endpoint = fetch_secret(\n sdg_secret_name,\
1934+
\ [\"api_token\", \"model_name\", \"endpoint\"]\n )\n print(\"\
1935+
SDG Teacher secret data retrieved.\")\n\n # A hack because InstructLab\
1936+
\ assumes the value for model_name is a valid path and the name of the model.\n\
1937+
\ tmp_model_path = os.path.join(tempfile.gettempdir(), model_name)\n\
1938+
\ # Since a model name can have a slash in it and InstructLab expects\
1939+
\ this to be a valid path as well, we must\n # pretend the slashes represent\
1940+
\ directories.\n if \"/\" in model_name:\n os.makedirs(os.path.dirname(tmp_model_path),\
1941+
\ exist_ok=True)\n os.symlink(tokenizer_model_path, tmp_model_path)\n\
1942+
\ os.chdir(tempfile.gettempdir())\n\n # Use the default SSL context\
1943+
\ since it leverages OpenSSL to use the correct CA bundle.\n http_client\
1944+
\ = httpx.Client(verify=ssl.create_default_context())\n client = openai.OpenAI(base_url=endpoint,\
1945+
\ api_key=api_key, http_client=http_client)\n\n taxonomy_base = \"main\"\
1946+
\ if repo_branch or (repo_pr and int(repo_pr) > 0) else \"empty\"\n\n \
1947+
\ print(\"Generating synthetic dataset for:\")\n print()\n print(\n\
1948+
\ instructlab.sdg.utils.taxonomy.read_taxonomy(\n taxonomy_path,\
1949+
\ taxonomy_base, document_output_dir=f\"{sdg_path}/documents\"\n \
1950+
\ )\n )\n\n # Generate synthetic dataset\n # 1.0 is the default\
1951+
\ size\n if sdg_sampling_size == 1.0:\n # generate_data has a\
1952+
\ magic word for its taxonomy_base argument - 'empty'\n # it allows\
1953+
\ generating from the whole repo, see:\n # https://github.com/instructlab/sdg/blob/c6a9e74a1618b1077cd38e713b8aaed8b7c0c8ce/src/instructlab/sdg/utils/taxonomy.py#L230\n\
19491954
\ instructlab.sdg.generate_data(\n client=client,\n \
19501955
\ num_instructions_to_generate=num_instructions_to_generate,\n\
19511956
\ output_dir=sdg_path,\n taxonomy=taxonomy_path,\n\
1952-
\ taxonomy_base=taxonomy_base,\n model_name=\"mixtral\"\
1953-
,\n pipeline=pipeline,\n chunk_word_count=1000,\n\
1954-
\ server_ctx_size=4096,\n batch_size=sdg_batch_size,\n\
1955-
\ num_cpus=sdg_num_cpus,\n )\n # Tweak precomputed\
1956-
\ skills data ratio if needed\n else:\n skills_recipe = \"/usr/share/instructlab/sdg/default_data_recipes/skills.yaml\"\
1957+
\ taxonomy_base=taxonomy_base,\n model_name=model_name,\n\
1958+
\ model_family=\"mixtral\",\n pipeline=pipeline,\n\
1959+
\ chunk_word_count=1000,\n server_ctx_size=4096,\n\
1960+
\ batch_size=sdg_batch_size,\n num_cpus=sdg_num_cpus,\n\
1961+
\ )\n # Tweak precomputed skills data ratio if needed\n else:\n\
1962+
\ skills_recipe = \"/usr/share/instructlab/sdg/default_data_recipes/skills.yaml\"\
19571963
\n\n def set_precomputed_skills_data_ratio(sampling_size: float,\
19581964
\ skills_recipe: str):\n if os.path.exists(skills_recipe):\n\
19591965
\ with open(skills_recipe, \"r\", encoding=\"utf-8\") as\
@@ -2015,13 +2021,14 @@ deploymentSpec:
20152021
\ client=client,\n num_instructions_to_generate=num_instructions_to_generate,\n\
20162022
\ output_dir=sdg_path,\n taxonomy=taxonomy_path,\n\
20172023
\ taxonomy_base=taxonomy_base,\n \
2018-
\ model_name=\"mixtral\",\n pipeline=pipeline,\n\
2019-
\ chunk_word_count=1000,\n \
2020-
\ server_ctx_size=4096,\n batch_size=sdg_batch_size,\n\
2021-
\ num_cpus=sdg_num_cpus,\n )\n\
2022-
\ except Exception as e:\n print(f\"Failed\
2023-
\ to set precomputed skills data ratio: {e}\")\n raise\n\
2024-
\n # Cleanup git configurations\n if git_credentials_path and os.path.exists(git_credentials_path):\n\
2024+
\ model_name=model_name,\n model_family=\"\
2025+
mixtral\",\n pipeline=pipeline,\n \
2026+
\ chunk_word_count=1000,\n server_ctx_size=4096,\n\
2027+
\ batch_size=sdg_batch_size,\n \
2028+
\ num_cpus=sdg_num_cpus,\n )\n except\
2029+
\ Exception as e:\n print(f\"Failed to set precomputed\
2030+
\ skills data ratio: {e}\")\n raise\n\n # Cleanup\
2031+
\ git configurations\n if git_credentials_path and os.path.exists(git_credentials_path):\n\
20252032
\ os.remove(git_credentials_path)\n print(f\"{git_credentials_path}\
20262033
\ deleted successfully\")\n if ssh_key_path and os.path.exists(ssh_key_path):\n\
20272034
\ os.remove(ssh_key_path)\n print(f\"{ssh_key_path} deleted\

Diff for: sdg/components.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,6 @@ def get_git_host(repo_url):
163163
escaped_uri = tokenizer_model_path[len("oci://") :].replace("/", "_")
164164
tokenizer_model_path = os.path.join("/oci", escaped_uri, "models")
165165

166-
# A hack because InstructLab assumes the value for model_name is a valid path and the name of the model.
167-
os.symlink(tokenizer_model_path, os.path.join(tempfile.gettempdir(), "mixtral"))
168-
os.chdir(tempfile.gettempdir())
169-
170166
if not taxonomy_repo_secret:
171167
username = os.getenv("GIT_USERNAME")
172168
token = os.getenv("GIT_TOKEN")
@@ -303,12 +299,24 @@ def get_git_host(repo_url):
303299

304300
if sdg_secret_name is None:
305301
api_key = os.getenv("api_key")
302+
model_name = os.getenv("model_name")
306303
endpoint = os.getenv("endpoint")
307304
else:
308305
print("SDG Teacher secret specified, fetching...")
309-
api_key, endpoint = fetch_secret(sdg_secret_name, ["api_token", "endpoint"])
306+
api_key, model_name, endpoint = fetch_secret(
307+
sdg_secret_name, ["api_token", "model_name", "endpoint"]
308+
)
310309
print("SDG Teacher secret data retrieved.")
311310

311+
# A hack because InstructLab assumes the value for model_name is a valid path and the name of the model.
312+
tmp_model_path = os.path.join(tempfile.gettempdir(), model_name)
313+
# Since a model name can have a slash in it and InstructLab expects this to be a valid path as well, we must
314+
# pretend the slashes represent directories.
315+
if "/" in model_name:
316+
os.makedirs(os.path.dirname(tmp_model_path), exist_ok=True)
317+
os.symlink(tokenizer_model_path, tmp_model_path)
318+
os.chdir(tempfile.gettempdir())
319+
312320
# Use the default SSL context since it leverages OpenSSL to use the correct CA bundle.
313321
http_client = httpx.Client(verify=ssl.create_default_context())
314322
client = openai.OpenAI(base_url=endpoint, api_key=api_key, http_client=http_client)
@@ -335,7 +343,8 @@ def get_git_host(repo_url):
335343
output_dir=sdg_path,
336344
taxonomy=taxonomy_path,
337345
taxonomy_base=taxonomy_base,
338-
model_name="mixtral",
346+
model_name=model_name,
347+
model_family="mixtral",
339348
pipeline=pipeline,
340349
chunk_word_count=1000,
341350
server_ctx_size=4096,
@@ -435,7 +444,8 @@ def set_precomputed_skills_data_ratio(sampling_size: float, skills_recipe: str):
435444
output_dir=sdg_path,
436445
taxonomy=taxonomy_path,
437446
taxonomy_base=taxonomy_base,
438-
model_name="mixtral",
447+
model_name=model_name,
448+
model_family="mixtral",
439449
pipeline=pipeline,
440450
chunk_word_count=1000,
441451
server_ctx_size=4096,

0 commit comments

Comments
 (0)