 from openlayer import Openlayer, AsyncOpenlayer, APIResponseValidationError
 from openlayer._types import Omit
+from openlayer._utils import maybe_transform
 from openlayer._models import BaseModel, FinalRequestOptions
 from openlayer._constants import RAW_RESPONSE_HEADER
 from openlayer._exceptions import APIStatusError, APITimeoutError, APIResponseValidationError
     BaseClient,
     make_request_options,
 )
+from openlayer.types.inference_pipelines.data_stream_params import DataStreamParams
 
 from .utils import update_env
 
@@ -730,23 +732,26 @@ def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
                 "/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream",
                 body=cast(
                     object,
-                    dict(
-                        config={
-                            "input_variable_names": ["user_query"],
-                            "output_column_name": "output",
-                            "num_of_token_column_name": "tokens",
-                            "cost_column_name": "cost",
-                            "timestamp_column_name": "timestamp",
-                        },
-                        rows=[
-                            {
-                                "user_query": "what is the meaning of life?",
-                                "output": "42",
-                                "tokens": 7,
-                                "cost": 0.02,
-                                "timestamp": 1610000000,
-                            }
-                        ],
+                    maybe_transform(
+                        dict(
+                            config={
+                                "input_variable_names": ["user_query"],
+                                "output_column_name": "output",
+                                "num_of_token_column_name": "tokens",
+                                "cost_column_name": "cost",
+                                "timestamp_column_name": "timestamp",
+                            },
+                            rows=[
+                                {
+                                    "user_query": "what is the meaning of life?",
+                                    "output": "42",
+                                    "tokens": 7,
+                                    "cost": 0.02,
+                                    "timestamp": 1610000000,
+                                }
+                            ],
+                        ),
+                        DataStreamParams,
                     ),
                 ),
                 cast_to=httpx.Response,
@@ -767,23 +772,26 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
                 "/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream",
                 body=cast(
                     object,
-                    dict(
-                        config={
-                            "input_variable_names": ["user_query"],
-                            "output_column_name": "output",
-                            "num_of_token_column_name": "tokens",
-                            "cost_column_name": "cost",
-                            "timestamp_column_name": "timestamp",
-                        },
-                        rows=[
-                            {
-                                "user_query": "what is the meaning of life?",
-                                "output": "42",
-                                "tokens": 7,
-                                "cost": 0.02,
-                                "timestamp": 1610000000,
-                            }
-                        ],
+                    maybe_transform(
+                        dict(
+                            config={
+                                "input_variable_names": ["user_query"],
+                                "output_column_name": "output",
+                                "num_of_token_column_name": "tokens",
+                                "cost_column_name": "cost",
+                                "timestamp_column_name": "timestamp",
+                            },
+                            rows=[
+                                {
+                                    "user_query": "what is the meaning of life?",
+                                    "output": "42",
+                                    "tokens": 7,
+                                    "cost": 0.02,
+                                    "timestamp": 1610000000,
+                                }
+                            ],
+                        ),
+                        DataStreamParams,
                     ),
                 ),
                 cast_to=httpx.Response,
@@ -1603,23 +1611,26 @@ async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
                 "/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream",
                 body=cast(
                     object,
-                    dict(
-                        config={
-                            "input_variable_names": ["user_query"],
-                            "output_column_name": "output",
-                            "num_of_token_column_name": "tokens",
-                            "cost_column_name": "cost",
-                            "timestamp_column_name": "timestamp",
-                        },
-                        rows=[
-                            {
-                                "user_query": "what is the meaning of life?",
-                                "output": "42",
-                                "tokens": 7,
-                                "cost": 0.02,
-                                "timestamp": 1610000000,
-                            }
-                        ],
+                    maybe_transform(
+                        dict(
+                            config={
+                                "input_variable_names": ["user_query"],
+                                "output_column_name": "output",
+                                "num_of_token_column_name": "tokens",
+                                "cost_column_name": "cost",
+                                "timestamp_column_name": "timestamp",
+                            },
+                            rows=[
+                                {
+                                    "user_query": "what is the meaning of life?",
+                                    "output": "42",
+                                    "tokens": 7,
+                                    "cost": 0.02,
+                                    "timestamp": 1610000000,
+                                }
+                            ],
+                        ),
+                        DataStreamParams,
                     ),
                 ),
                 cast_to=httpx.Response,
@@ -1640,23 +1651,26 @@ async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
                 "/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream",
                 body=cast(
                     object,
-                    dict(
-                        config={
-                            "input_variable_names": ["user_query"],
-                            "output_column_name": "output",
-                            "num_of_token_column_name": "tokens",
-                            "cost_column_name": "cost",
-                            "timestamp_column_name": "timestamp",
-                        },
-                        rows=[
-                            {
-                                "user_query": "what is the meaning of life?",
-                                "output": "42",
-                                "tokens": 7,
-                                "cost": 0.02,
-                                "timestamp": 1610000000,
-                            }
-                        ],
+                    maybe_transform(
+                        dict(
+                            config={
+                                "input_variable_names": ["user_query"],
+                                "output_column_name": "output",
+                                "num_of_token_column_name": "tokens",
+                                "cost_column_name": "cost",
+                                "timestamp_column_name": "timestamp",
+                            },
+                            rows=[
+                                {
+                                    "user_query": "what is the meaning of life?",
+                                    "output": "42",
+                                    "tokens": 7,
+                                    "cost": 0.02,
+                                    "timestamp": 1610000000,
+                                }
+                            ],
+                        ),
+                        DataStreamParams,
                     ),
                 ),
                 cast_to=httpx.Response,
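
For reference, a minimal sketch of the pattern the updated tests follow, built only from identifiers that appear in the diff above. The variable names raw_body and payload are placeholders of mine, and the respx/pytest scaffolding around the post call is omitted: running the raw dict through maybe_transform together with DataStreamParams serializes it into the shape that params type declares before it is handed to the low-level client.post call.

# Sketch only: mirrors the expected-body construction introduced by this diff.
from typing import cast

from openlayer._utils import maybe_transform
from openlayer.types.inference_pipelines.data_stream_params import DataStreamParams

raw_body = dict(
    config={
        "input_variable_names": ["user_query"],
        "output_column_name": "output",
        "num_of_token_column_name": "tokens",
        "cost_column_name": "cost",
        "timestamp_column_name": "timestamp",
    },
    rows=[
        {
            "user_query": "what is the meaning of life?",
            "output": "42",
            "tokens": 7,
            "cost": 0.02,
            "timestamp": 1610000000,
        }
    ],
)

# maybe_transform applies whatever key aliases/serialization rules DataStreamParams
# declares; the transformed dict is what the tests now pass as body= to client.post.
payload = cast(object, maybe_transform(raw_body, DataStreamParams))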