|
41 | 41 | ModelProvenanceNotFoundError,
|
42 | 42 | OCIDataScienceModel,
|
43 | 43 | )
|
| 44 | +from ads.common import oci_client as oc |
44 | 45 |
|
45 | 46 | logger = logging.getLogger(__name__)
|
46 | 47 |
|
@@ -1466,3 +1467,226 @@ def _download_file_description_artifact(self) -> Tuple[Union[str, List[str]], in
|
1466 | 1467 | bucket_uri.append(uri)
|
1467 | 1468 |
|
1468 | 1469 | return bucket_uri[0] if len(bucket_uri) == 1 else bucket_uri, artifact_size
|
| 1470 | + |
def add_artifact(
    self,
    uri: Optional[str] = None,
    namespace: Optional[str] = None,
    bucket: Optional[str] = None,
    prefix: Optional[str] = None,
    files: Optional[List[str]] = None,
):
    """
    Adds information about objects in a specified bucket to the model description JSON.

    Any entry already recorded for the same (namespace, bucket, prefix) is
    replaced. The existing entry is only dropped once the new object list has
    been collected and found non-empty, so a failed call does not discard it.

    Parameters
    ----------
    uri : str, optional
        The URI representing the location of the artifact in OCI object storage.
        Bucket, namespace and prefix are derived from it via
        `ObjectStorageDetails.from_path`.
    namespace : str, optional
        The namespace of the bucket containing the objects. Required if `uri` is not provided.
    bucket : str, optional
        The name of the bucket containing the objects. Required if `uri` is not provided.
    prefix : str, optional
        The prefix of the objects to add. Defaults to None. Cannot be provided if `files` is provided.
    files : list of str, optional
        Object names to include in the model description. Only objects whose
        name matches exactly are included. Cannot be provided if `prefix` is
        provided (including a prefix derived from `uri`).

    Returns
    -------
    None

    Raises
    ------
    ValueError
        - If `uri` is provided together with `namespace` or `bucket`.
        - If neither `uri` nor both `namespace` and `bucket` are provided.
        - If both `prefix` and `files` are provided.
        - If no files are found to add to the model description.

    Note
    ----
    - If `files` is not provided, every object version returned by the
      listing call for `prefix` is considered (all pages are fetched).
    - If `files` is provided, each name is checked with a HEAD request first;
      names that cannot be confirmed are logged and skipped.
    - Entries without a positive size are ignored.
    """

    if uri and (namespace or bucket):
        raise ValueError(
            "Either 'uri' must be provided or both 'namespace' and 'bucket' must be provided."
        )
    if uri:
        object_storage_details = ObjectStorageDetails.from_path(uri)
        bucket = object_storage_details.bucket
        namespace = object_storage_details.namespace
        prefix = (
            None
            if object_storage_details.filepath == ""
            else object_storage_details.filepath
        )
    if (not namespace) or (not bucket):
        raise ValueError("Both 'namespace' and 'bucket' must be provided.")

    # Check if both prefix and files are provided
    if prefix is not None and files is not None:
        raise ValueError(
            "Both 'prefix' and 'files' cannot be provided. Please provide only one."
        )

    if self.model_file_description is None:
        self.empty_json = {
            "version": "1.0",
            "type": "modelOSSReferenceDescription",
            "models": [],
        }
        self.set_spec(self.CONST_MODEL_FILE_DESCRIPTION, self.empty_json)

    # Get object storage client
    self.object_storage_client = oc.OCIClientFactory(
        **(self.dsc_model.auth)
    ).object_storage
    client = self.object_storage_client

    def check_if_file_exists(file_name):
        # Best-effort existence probe: any failure is logged and reported
        # as "does not exist" so the remaining files can still be processed.
        try:
            head_response = client.head_object(
                namespace, bucket, object_name=file_name
            )
        except Exception as e:
            if hasattr(e, "status") and e.status == 404:
                logger.error(f"File not found in bucket: {file_name}")
            else:
                logger.error(f"An error occured: {e}")
            return False
        return head_response.status == 200

    # Function to un-paginate the api call with while loop
    def list_obj_versions_unpaginated():
        collected = []
        has_next_page, opc_next_page = True, None
        while has_next_page:
            response = client.list_object_versions(
                namespace_name=namespace,
                bucket_name=bucket,
                prefix=prefix,
                fields="name,size",
                page=opc_next_page,
            )
            collected.extend(response.data.items)
            has_next_page = response.has_next_page
            opc_next_page = response.next_page
        return collected

    # Fetch object details and put it into the objects variable
    object_storage_list = []
    if files is None:
        object_storage_list = list_obj_versions_unpaginated()
    else:
        for file_name in files:
            if not check_if_file_exists(file_name):
                continue
            listed = client.list_object_versions(
                namespace_name=namespace,
                bucket_name=bucket,
                prefix=file_name,
                fields="name,size",
            ).data.items
            # The listing is a prefix match, so it can also return objects
            # whose names merely start with `file_name`; keep the first
            # entry whose name is exactly the requested one.
            exact = next(
                (item for item in listed if item.name == file_name), None
            )
            if exact is None:
                logger.error(
                    f"No object version listed with exact name: {file_name}"
                )
                continue
            object_storage_list.append(exact)

    # Truthiness keeps only positive sizes and also skips entries whose size
    # is None, where a bare `size > 0` comparison would raise TypeError.
    objects = [
        {"name": obj.name, "version": obj.version_id, "sizeInBytes": obj.size}
        for obj in object_storage_list
        if obj.size
    ]

    if len(objects) == 0:
        error_message = (
            f"No files to add in the bucket: {bucket} with namespace: {namespace} "
            f"and prefix: {prefix}. File names: {files}"
        )
        logger.error(error_message)
        raise ValueError(error_message)

    # Remove if the model already exists. Done only now, after the new
    # object list is known to be valid, so a failed add keeps the old entry.
    self.remove_artifact(namespace=namespace, bucket=bucket, prefix=prefix)

    tmp_model_file_description = self.model_file_description
    tmp_model_file_description["models"].append(
        {
            "namespace": namespace,
            "bucketName": bucket,
            "prefix": "" if not prefix else prefix,
            "objects": objects,
        }
    )
    self.set_spec(self.CONST_MODEL_FILE_DESCRIPTION, tmp_model_file_description)
| 1624 | + |
def remove_artifact(
    self,
    uri: Optional[str] = None,
    namespace: Optional[str] = None,
    bucket: Optional[str] = None,
    prefix: Optional[str] = None,
):
    """
    Removes information about objects in a specified bucket or using a specified URI from the model description JSON.

    The first entry whose namespace, bucket and prefix match is removed in
    place from the dictionary returned by `self.model_file_description`; no
    separate `set_spec` call is made. A missing or empty prefix on either
    side is treated as the empty string.

    Parameters
    ----------
    uri : str, optional
        The URI representing the location of the artifact in OCI object storage.
    namespace : str, optional
        The namespace of the bucket containing the objects. Required if `uri` is not provided.
    bucket : str, optional
        The name of the bucket containing the objects. Required if `uri` is not provided.
    prefix : str, optional
        The prefix of the objects to remove. Defaults to None.

    Returns
    -------
    None
        Also returns silently when there is no model description yet or when
        no matching entry exists.

    Raises
    ------
    ValueError
        - If 'uri' is provided together with 'namespace' or 'bucket'.
        - If neither 'uri' nor both 'namespace' and 'bucket' are provided.
    """

    if uri and (namespace or bucket):
        raise ValueError(
            "Either 'uri' must be provided or both 'namespace' and 'bucket' must be provided."
        )
    if uri:
        object_storage_details = ObjectStorageDetails.from_path(uri)
        bucket = object_storage_details.bucket
        namespace = object_storage_details.namespace
        prefix = (
            None
            if object_storage_details.filepath == ""
            else object_storage_details.filepath
        )
    if (not namespace) or (not bucket):
        raise ValueError("Both 'namespace' and 'bucket' must be provided.")

    description = self.model_file_description
    if description is None:
        return

    # Normalise the prefix identically on both sides of the comparison so an
    # entry without a "prefix" key matches a request without a prefix.
    target = (namespace, bucket, prefix or "")
    models = description["models"]
    for idx, model in enumerate(models):
        candidate = (
            model["namespace"],
            model["bucketName"],
            model.get("prefix") or "",
        )
        if candidate == target:
            models.pop(idx)
            return
0 commit comments