File tree Expand file tree Collapse file tree 8 files changed +45
-31
lines changed Expand file tree Collapse file tree 8 files changed +45
-31
lines changed Original file line number Diff line number Diff line change 55
55
pylint --indent-string=' ' jetstream_pt/ benchmarks/
56
56
- name : Format check with pyink
57
57
run : |
58
- pyink --pyink-indentation 2 --line-length 80 --check --verbose .
58
+ pyink --pyink-indentation 2 --line-length 80 --check --verbose --extend-exclude=deps .
59
59
60
60
cpu :
61
61
name : " jetstream_pt unit tests"
79
79
JAX_PLATFORMS=cpu coverage run -m unittest -v
80
80
- name : Create test coverage report
81
81
run : |
82
- coverage report -m
82
+ coverage report -m
83
+
84
+ interactive :
85
+ name : " jetstream_pt run interactive"
86
+ strategy :
87
+ matrix :
88
+ os : [ubuntu-20.04]
89
+ python-version : ['3.10']
90
+ runs-on : ${{ matrix.os }}
91
+ steps :
92
+ - name : Checkout
93
+ uses : actions/checkout@v4
94
+ - name : Setup Python
95
+ uses : actions/setup-python@v4
96
+ with :
97
+ python-version : ${{ matrix.python-version }}
98
+ - name : Install Dependencies
99
+ run : |
100
+ source install_everything.sh
101
+ - name : Run interactive (bf16)
102
+ run : |
103
+ JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama.yaml --quantize_weights=0 --quantize_kv_cache=0
104
+ - name : Run interactive (int8)
105
+ run : |
106
+ JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama.yaml --quantize_weights=1 --quantize_kv_cache=1
Original file line number Diff line number Diff line change 1
- # source dependencies
2
- deps /
3
-
4
1
# Byte-compiled / optimized / DLL files
5
2
__pycache__ /
6
3
* .py [cod ]
Original file line number Diff line number Diff line change
1
+ [submodule "deps/JetStream "]
2
+ path = deps/JetStream
3
+ url = https://github.com/google/JetStream.git
4
+ [submodule "deps/xla "]
5
+ path = deps/xla
6
+ url = https://github.com/pytorch/xla.git
Original file line number Diff line number Diff line change 12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- TORCHXLA_TAG=f26c35c2fa5eb1d22d042a2a8a8dc34f11b99f60 # updated May 14, 2024
16
- JETSTREAM_TAG=e4952fbb12e0ab3c33bc7c1eef3839b7c2ad0dd4 # updated May 16, 2024
17
-
18
15
# Uninstall existing jax
19
16
pip show jax && pip uninstall -y jax
20
17
pip show jaxlib && pip uninstall -y jaxlib
@@ -26,17 +23,5 @@ pip install torch --index-url https://download.pytorch.org/whl/cpu
26
23
pip install tensorflow flatbuffers absl-py flax sentencepiece seqio google-cloud-storage
27
24
pip install safetensors colorama coverage ray[default] humanize
28
25
29
- mkdir -p deps
30
- pushd deps
31
- git clone https://github.com/google/JetStream.git
32
- git clone https://github.com/pytorch/xla.git
33
- pushd xla/experimental/torch_xla2
34
- git checkout $TORCHXLA_TAG
35
- pip install .
36
- popd # now at the folder deps
37
- pushd JetStream
38
- git checkout $JETSTREAM_TAG
39
- pip install .
40
- popd # now at the folder deps
41
- popd # now at the folder current file
26
+ git submodule update --init --recursive
42
27
pip install -e .
Original file line number Diff line number Diff line change @@ -3,7 +3,7 @@ requires = ["hatchling"]
3
3
build-backend = " hatchling.build"
4
4
5
5
[project ]
6
- version = " 0.2.0 "
6
+ version = " 0.2.1 "
7
7
name = " jetstream_pt"
8
8
dependencies = [
9
9
" absl-py" ,
@@ -14,7 +14,12 @@ dependencies = [
14
14
" google-jetstream" ,
15
15
" google-cloud-storage" ,
16
16
" safetensors" ,
17
+ " torch_xla2 @ {root:uri}/deps/xla/experimental/torch_xla2" ,
18
+ " google-jetstream @ {root:uri}/deps/JetStream" ,
17
19
]
18
20
19
21
requires-python = " >=3.10"
20
22
license = {file = " LICENSE" }
23
+
24
+ [tool .hatch .metadata ]
25
+ allow-direct-references = true
Original file line number Diff line number Diff line change @@ -158,20 +158,15 @@ def main(argv):
158
158
decode_state , result_tokens = engine .generate (params , decode_state )
159
159
result_tokens = result_tokens .convert_to_numpy ()
160
160
res = result_tokens .get_result_at_slot (slot )
161
- stop_tokens = set (tokenizer .tokenizer . stop_tokens )
161
+ stop_tokens = set (tokenizer .stop_tokens )
162
162
stop_tokens .add (tokenizer .pad_id )
163
+ token_id = res .tokens [0 ][0 ].item ()
164
+ sampled_tokens_list .append (token_id )
163
165
if (
164
- res . tokens [ 0 ][ 0 ] in stop_tokens
166
+ token_id in stop_tokens
165
167
or len (sampled_tokens_list ) > max_output_length
166
168
):
167
169
break
168
- token_id = res .tokens [0 ][0 ]
169
- sampled_tokens_list .append (token_id )
170
- # output_str = tokenizer.decode_str([token_id])
171
- # print(Fore.GREEN + output_str, end="", flush=True)
172
-
173
- # print(Style.RESET_ALL + "\n")
174
- # print("---- Streaming decode finished.")
175
170
176
171
print ("---- All output tokens." )
177
172
print (sampled_tokens_list )
You can’t perform that action at this time.
0 commit comments