Skip to content

Commit 0abcc03

Browse files
authored
Merge branch 'main' into u/jakerad/generic-math
2 parents c42c1e2 + 3d705bf commit 0abcc03

File tree

6 files changed

+321
-5
lines changed

6 files changed

+321
-5
lines changed

eng/Versions.props

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
<MicrosoftMLTestDatabasesVersion>0.0.6-test</MicrosoftMLTestDatabasesVersion>
8888
<MicrosoftMLTestModelsVersion>0.0.7-test</MicrosoftMLTestModelsVersion>
8989
<SystemDataSqlClientVersion>4.6.1</SystemDataSqlClientVersion>
90-
<SystemDataSQLiteCoreVersion>1.0.112.2</SystemDataSQLiteCoreVersion>
90+
<SystemDataSQLiteCoreVersion>1.0.113</SystemDataSQLiteCoreVersion>
9191
<XunitCombinatorialVersion>1.2.7</XunitCombinatorialVersion>
9292
<XUnitVersion>2.4.2</XUnitVersion>
9393
<!-- Opt-out repo features -->

eng/helix.proj

+1-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@
9696
</MSBuild>
9797

9898
<PropertyGroup>
99-
<HelixPreCommands Condition="$(IsPosixShell)">$(HelixPreCommands);export ML_TEST_DATADIR=$HELIX_CORRELATION_PAYLOAD;export MICROSOFTML_RESOURCE_PATH=$HELIX_WORKITEM_ROOT;sudo chmod -R 777 $HELIX_WORKITEM_ROOT;sudo chown -R $(whoami) $HELIX_WORKITEM_ROOT</HelixPreCommands>
99+
<HelixPreCommands Condition="$(IsPosixShell)">$(HelixPreCommands);export ML_TEST_DATADIR=$HELIX_CORRELATION_PAYLOAD;export MICROSOFTML_RESOURCE_PATH=$HELIX_WORKITEM_ROOT;sudo chmod -R 777 $HELIX_WORKITEM_ROOT;sudo chown -R $USER $HELIX_WORKITEM_ROOT</HelixPreCommands>
100100
<HelixPreCommands Condition="!$(IsPosixShell)">$(HelixPreCommands);set ML_TEST_DATADIR=%HELIX_CORRELATION_PAYLOAD%;set MICROSOFTML_RESOURCE_PATH=%HELIX_WORKITEM_ROOT%</HelixPreCommands>
101101

102102
<HelixPreCommands Condition="$(HelixTargetQueues.ToLowerInvariant().Contains('osx'))">$(HelixPreCommands);install_name_tool -change "/usr/local/opt/libomp/lib/libomp.dylib" "@loader_path/libomp.dylib" libSymSgdNative.dylib</HelixPreCommands>

src/Microsoft.Data.Analysis/DataFrame.IO.cs

+157-3
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44

55
using System;
66
using System.Collections.Generic;
7+
using System.Data;
8+
using System.Data.Common;
79
using System.Globalization;
810
using System.IO;
911
using System.Text;
12+
using System.Threading.Tasks;
1013

1114
namespace Microsoft.Data.Analysis
1215
{
@@ -109,12 +112,158 @@ public static DataFrame LoadCsv(string filename,
109112
}
110113
}
111114

115+
public static DataFrame LoadFrom(IEnumerable<IList<object>> vals, IList<(string, Type)> columnInfos)
116+
{
117+
var columnsCount = columnInfos.Count;
118+
var columns = new List<DataFrameColumn>(columnsCount);
119+
120+
foreach (var (name, type) in columnInfos)
121+
{
122+
var column = CreateColumn(type, name);
123+
columns.Add(column);
124+
}
125+
126+
var res = new DataFrame(columns);
127+
128+
foreach (var items in vals)
129+
{
130+
for (var c = 0; c < items.Count; c++)
131+
{
132+
items[c] = items[c];
133+
}
134+
res.Append(items, inPlace: true);
135+
}
136+
137+
return res;
138+
}
139+
140+
public void SaveTo(DataTable table)
141+
{
142+
var columnsCount = Columns.Count;
143+
144+
if (table.Columns.Count == 0)
145+
{
146+
foreach (var column in Columns)
147+
{
148+
table.Columns.Add(column.Name, column.DataType);
149+
}
150+
}
151+
else
152+
{
153+
if (table.Columns.Count != columnsCount)
154+
throw new ArgumentException();
155+
for (var c = 0; c < columnsCount; c++)
156+
{
157+
if (table.Columns[c].DataType != Columns[c].DataType)
158+
throw new ArgumentException();
159+
}
160+
}
161+
162+
var items = new object[columnsCount];
163+
foreach (var row in Rows)
164+
{
165+
for (var c = 0; c < columnsCount; c++)
166+
{
167+
items[c] = row[c] ?? DBNull.Value;
168+
}
169+
table.Rows.Add(items);
170+
}
171+
}
172+
173+
public DataTable ToTable()
174+
{
175+
var res = new DataTable();
176+
SaveTo(res);
177+
return res;
178+
}
179+
180+
public static DataFrame FromSchema(DbDataReader reader)
181+
{
182+
var columnsCount = reader.FieldCount;
183+
var columns = new DataFrameColumn[columnsCount];
184+
185+
for (var c = 0; c < columnsCount; c++)
186+
{
187+
var type = reader.GetFieldType(c);
188+
var name = reader.GetName(c);
189+
var column = CreateColumn(type, name);
190+
columns[c] = column;
191+
}
192+
193+
var res = new DataFrame(columns);
194+
return res;
195+
}
196+
197+
public static async Task<DataFrame> LoadFrom(DbDataReader reader)
198+
{
199+
var res = FromSchema(reader);
200+
var columnsCount = reader.FieldCount;
201+
202+
var items = new object[columnsCount];
203+
while (await reader.ReadAsync())
204+
{
205+
for (var c = 0; c < columnsCount; c++)
206+
{
207+
items[c] = reader.IsDBNull(c)
208+
? null
209+
: reader[c];
210+
}
211+
res.Append(items, inPlace: true);
212+
}
213+
214+
reader.Close();
215+
216+
return res;
217+
}
218+
219+
public static async Task<DataFrame> LoadFrom(DbDataAdapter adapter)
220+
{
221+
using var reader = await adapter.SelectCommand.ExecuteReaderAsync();
222+
return await LoadFrom(reader);
223+
}
224+
225+
public void SaveTo(DbDataAdapter dataAdapter, DbProviderFactory factory)
226+
{
227+
using var commandBuilder = factory.CreateCommandBuilder();
228+
commandBuilder.DataAdapter = dataAdapter;
229+
dataAdapter.InsertCommand = commandBuilder.GetInsertCommand();
230+
dataAdapter.UpdateCommand = commandBuilder.GetUpdateCommand();
231+
dataAdapter.DeleteCommand = commandBuilder.GetDeleteCommand();
232+
233+
using var table = ToTable();
234+
235+
var connection = dataAdapter.SelectCommand.Connection;
236+
var needClose = connection.TryOpen();
237+
238+
try
239+
{
240+
using var transaction = connection.BeginTransaction();
241+
try
242+
{
243+
dataAdapter.Update(table);
244+
}
245+
catch
246+
{
247+
transaction.Rollback();
248+
transaction.Dispose();
249+
throw;
250+
}
251+
transaction.Commit();
252+
}
253+
finally
254+
{
255+
if (needClose)
256+
connection.Close();
257+
}
258+
}
259+
112260
/// <summary>
113261
/// return <paramref name="columnIndex"/> of <paramref name="columnNames"/> if not null or empty, otherwise return "Column{i}" where i is <paramref name="columnIndex"/>.
114262
/// </summary>
115263
/// <param name="columnNames">column names.</param>
116264
/// <param name="columnIndex">column index.</param>
117265
/// <returns></returns>
266+
118267
private static string GetColumnName(string[] columnNames, int columnIndex)
119268
{
120269
var defaultColumnName = "Column" + columnIndex.ToString();
@@ -126,7 +275,7 @@ private static string GetColumnName(string[] columnNames, int columnIndex)
126275
return defaultColumnName;
127276
}
128277

129-
private static DataFrameColumn CreateColumn(Type kind, string[] columnNames, int columnIndex)
278+
private static DataFrameColumn CreateColumn(Type kind, string columnName)
130279
{
131280
DataFrameColumn ret;
132281
if (kind == typeof(bool))
@@ -143,7 +292,7 @@ private static DataFrameColumn CreateColumn(Type kind, string[] columnNames, int
143292
}
144293
else if (kind == typeof(string))
145294
{
146-
ret = new StringDataFrameColumn(GetColumnName(columnNames, columnIndex), 0);
295+
ret = new StringDataFrameColumn(columnName, 0);
147296
}
148297
else if (kind == typeof(long))
149298
{
@@ -187,7 +336,7 @@ private static DataFrameColumn CreateColumn(Type kind, string[] columnNames, int
187336
}
188337
else if (kind == typeof(DateTime))
189338
{
190-
ret = new PrimitiveDataFrameColumn<DateTime>(GetColumnName(columnNames, columnIndex));
339+
ret = new PrimitiveDataFrameColumn<DateTime>(columnName);
191340
}
192341
else
193342
{
@@ -196,6 +345,11 @@ private static DataFrameColumn CreateColumn(Type kind, string[] columnNames, int
196345
return ret;
197346
}
198347

348+
private static DataFrameColumn CreateColumn(Type kind, string[] columnNames, int columnIndex)
349+
{
350+
return CreateColumn(kind, GetColumnName(columnNames, columnIndex));
351+
}
352+
199353
private static DataFrame ReadCsvLinesIntoDataFrame(WrappedStreamReaderOrStringReader wrappedReader,
200354
char separator = ',', bool header = true,
201355
string[] columnNames = null, Type[] dataTypes = null,
+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Data;
8+
using System.Data.Common;
9+
using System.Text;
10+
11+
namespace Microsoft.Data.Analysis
12+
{
13+
public static class Extensions
14+
{
15+
public static DbDataAdapter CreateDataAdapter(this DbProviderFactory factory, DbConnection connection, string tableName)
16+
{
17+
var query = connection.CreateCommand();
18+
query.CommandText = $"SELECT * FROM {tableName}";
19+
var res = factory.CreateDataAdapter();
20+
res.SelectCommand = query;
21+
return res;
22+
}
23+
24+
public static bool TryOpen(this DbConnection connection)
25+
{
26+
if (connection.State == ConnectionState.Closed)
27+
{
28+
connection.Open();
29+
return true;
30+
}
31+
else
32+
{
33+
return false;
34+
}
35+
}
36+
}
37+
}

0 commit comments

Comments
 (0)