Skip to content

Commit 0e5607c

Browse files
committed
Use git submodule
1 parent 0f4e7ae commit 0e5607c

File tree

8 files changed

+45
-31
lines changed

8 files changed

+45
-31
lines changed

.github/workflows/unit_tests.yaml

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ jobs:
5555
pylint --indent-string=' ' jetstream_pt/ benchmarks/
5656
- name: Format check with pyink
5757
run: |
58-
pyink --pyink-indentation 2 --line-length 80 --check --verbose .
58+
pyink --pyink-indentation 2 --line-length 80 --check --verbose --extend-exclude=deps .
5959
6060
cpu:
6161
name: "jetstream_pt unit tests"
@@ -79,4 +79,28 @@ jobs:
7979
JAX_PLATFORMS=cpu coverage run -m unittest -v
8080
- name: Create test coverage report
8181
run: |
82-
coverage report -m
82+
coverage report -m
83+
84+
interactive:
85+
name: "jetstream_pt run interactive"
86+
strategy:
87+
matrix:
88+
os: [ubuntu-20.04]
89+
python-version: ['3.10']
90+
runs-on: ${{ matrix.os }}
91+
steps:
92+
- name: Checkout
93+
uses: actions/checkout@v4
94+
- name: Setup Python
95+
uses: actions/setup-python@v4
96+
with:
97+
python-version: ${{ matrix.python-version }}
98+
- name: Install Dependencies
99+
run: |
100+
source install_everything.sh
101+
- name: Run interactive (bf16)
102+
run: |
103+
JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama.yaml --quantize_weights=0 --quantize_kv_cache=0
104+
- name: Run interactive (int8)
105+
run: |
106+
JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama.yaml --quantize_weights=1 --quantize_kv_cache=1

.gitignore

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
# source dependencies
2-
deps/
3-
41
# Byte-compiled / optimized / DLL files
52
__pycache__/
63
*.py[cod]

.gitmodules

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[submodule "deps/JetStream"]
2+
path = deps/JetStream
3+
url = https://github.com/google/JetStream.git
4+
[submodule "deps/xla"]
5+
path = deps/xla
6+
url = https://github.com/pytorch/xla.git

deps/JetStream

Submodule JetStream added at 8128c8a

deps/xla

Submodule xla added at f26c35c

install_everything.sh

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
TORCHXLA_TAG=f26c35c2fa5eb1d22d042a2a8a8dc34f11b99f60 # updated May 14, 2024
16-
JETSTREAM_TAG=e4952fbb12e0ab3c33bc7c1eef3839b7c2ad0dd4 # updated May 16, 2024
17-
1815
# Uninstall existing jax
1916
pip show jax && pip uninstall -y jax
2017
pip show jaxlib && pip uninstall -y jaxlib
@@ -26,17 +23,5 @@ pip install torch --index-url https://download.pytorch.org/whl/cpu
2623
pip install tensorflow flatbuffers absl-py flax sentencepiece seqio google-cloud-storage
2724
pip install safetensors colorama coverage ray[default] humanize
2825

29-
mkdir -p deps
30-
pushd deps
31-
git clone https://github.com/google/JetStream.git
32-
git clone https://github.com/pytorch/xla.git
33-
pushd xla/experimental/torch_xla2
34-
git checkout $TORCHXLA_TAG
35-
pip install .
36-
popd # now at the folder deps
37-
pushd JetStream
38-
git checkout $JETSTREAM_TAG
39-
pip install .
40-
popd # now at the folder deps
41-
popd # now at the folder current file
26+
git submodule update --init --recursive
4227
pip install -e .

pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ requires = ["hatchling"]
33
build-backend = "hatchling.build"
44

55
[project]
6-
version = "0.2.0"
6+
version = "0.2.1"
77
name = "jetstream_pt"
88
dependencies = [
99
"absl-py",
@@ -14,7 +14,12 @@ dependencies = [
1414
"google-jetstream",
1515
"google-cloud-storage",
1616
"safetensors",
17+
"torch_xla2 @ {root:uri}/deps/xla/experimental/torch_xla2",
18+
"google-jetstream @ {root:uri}/deps/JetStream",
1719
]
1820

1921
requires-python = ">=3.10"
2022
license = {file = "LICENSE"}
23+
24+
[tool.hatch.metadata]
25+
allow-direct-references = true

run_interactive.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -158,20 +158,15 @@ def main(argv):
158158
decode_state, result_tokens = engine.generate(params, decode_state)
159159
result_tokens = result_tokens.convert_to_numpy()
160160
res = result_tokens.get_result_at_slot(slot)
161-
stop_tokens = set(tokenizer.tokenizer.stop_tokens)
161+
stop_tokens = set(tokenizer.stop_tokens)
162162
stop_tokens.add(tokenizer.pad_id)
163+
token_id = res.tokens[0][0].item()
164+
sampled_tokens_list.append(token_id)
163165
if (
164-
res.tokens[0][0] in stop_tokens
166+
token_id in stop_tokens
165167
or len(sampled_tokens_list) > max_output_length
166168
):
167169
break
168-
token_id = res.tokens[0][0]
169-
sampled_tokens_list.append(token_id)
170-
# output_str = tokenizer.decode_str([token_id])
171-
# print(Fore.GREEN + output_str, end="", flush=True)
172-
173-
# print(Style.RESET_ALL + "\n")
174-
# print("---- Streaming decode finished.")
175170

176171
print("---- All output tokens.")
177172
print(sampled_tokens_list)

0 commit comments

Comments (0)