Skip to content

Commit 8a6b01a

Browse files
Modify DirectML_ESRGAN's default adapter selection.
Modify DirectML_ESRGAN sample to by default try NPU first then retry GPU. If adapter specified through param, stay with the specified param.
1 parent a2b2dd2 commit 8a6b01a

File tree

1 file changed

+31
-7
lines changed

1 file changed

+31
-7
lines changed

Samples/DirectML_ESRGAN/main.cpp

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,37 @@ int main(int argc, char** argv)
135135
("a,adapter", "Adapter name substring filter", cxxopts::value<std::string>()->default_value(""));
136136

137137
auto commandLineArgs = commandLineParams.parse(argc, argv);
138-
139-
// See helpers.h for logic to select a DXCore adapter, create DML device, and create D3D command queue.
140-
auto [dmlDevice, commandQueue] = CreateDmlDeviceAndCommandQueue(commandLineArgs["adapter"].as<std::string>());
141-
138+
std::string adapter = commandLineArgs["adapter"].as<std::string>();
139+
140+
// Check if adapter param is empty. If so, try "NPU" first then try "GPU" too.
141+
if (adapter.empty())
142+
{
143+
try
144+
{
145+
adapter = "NPU"; // First try with NPU
146+
auto [dmlDevice, commandQueue] = CreateDmlDeviceAndCommandQueue(adapter);
147+
RunModel(
148+
dmlDevice.Get(),
149+
commandQueue.Get(),
150+
commandLineArgs["model"].as<std::string>(),
151+
commandLineArgs["image"].as<std::string>()
152+
);
153+
return 0; // Exit if successful
154+
}
155+
catch (const std::exception& e)
156+
{
157+
std::cerr << "Error: " << e.what() << std::endl;
158+
std::cout << "Retrying on GPU..." << std::endl;
159+
// If NPU fails, fallback to GPU
160+
adapter = "GPU";
161+
}
162+
}
163+
164+
// Final attempt with the specified or fallback adapter
165+
auto [dmlDevice, commandQueue] = CreateDmlDeviceAndCommandQueue(adapter);
142166
RunModel(
143-
dmlDevice.Get(),
144-
commandQueue.Get(),
167+
dmlDevice.Get(),
168+
commandQueue.Get(),
145169
commandLineArgs["model"].as<std::string>(),
146170
commandLineArgs["image"].as<std::string>()
147171
);
@@ -154,4 +178,4 @@ int main(int argc, char** argv)
154178
}
155179

156180
return 0;
157-
}
181+
}

0 commit comments

Comments
 (0)