|
9 | 9 | from django_filters import FilterSet
|
10 | 10 | from django_filters import UUIDFilter
|
11 | 11 | from django_filters.rest_framework import DjangoFilterBackend
|
| 12 | +from django_tenants.utils import schema_context |
12 | 13 | from rest_framework import mixins
|
13 | 14 | from rest_framework import viewsets
|
14 | 15 | from rest_framework.permissions import AllowAny
|
|
17 | 18 |
|
18 | 19 | from api.common.filters import CharListFilter
|
19 | 20 | from api.provider.models import Sources
|
| 21 | +from cost_models.models import CostModelMap |
20 | 22 | from masu.api.sources.serializers import SourceSerializer
|
21 | 23 |
|
22 | 24 | MIXIN_LIST = [mixins.ListModelMixin, mixins.RetrieveModelMixin, mixins.UpdateModelMixin, viewsets.GenericViewSet]
|
@@ -97,3 +99,24 @@ def get_object(self):
|
97 | 99 | raise Http404
|
98 | 100 |
|
99 | 101 | return obj
|
| 102 | + |
| 103 | + def list(self, request, *args, **kwargs): |
| 104 | + """Obtain the list of sources.""" |
| 105 | + response = super().list(request=request, args=args, kwargs=kwargs) |
| 106 | + for obj in response.data["data"]: |
| 107 | + obj["cost_models"] = self.get_cost_models(obj) |
| 108 | + return response |
| 109 | + |
| 110 | + def retrieve(self, request, *args, **kwargs): |
| 111 | + """Get a source.""" |
| 112 | + response = super().retrieve(request=request, args=args, kwargs=kwargs) |
| 113 | + response.data["cost_models"] = self.get_cost_models(response.data) |
| 114 | + return response |
| 115 | + |
| 116 | + def get_cost_models(self, obj): |
| 117 | + """Get the cost models associated with this provider.""" |
| 118 | + if not (schema := obj.get("provider", {}).get("customer", {}).get("schema_name")): |
| 119 | + return [] |
| 120 | + with schema_context(schema): |
| 121 | + cost_models_map = CostModelMap.objects.filter(provider_uuid=obj["source_uuid"]) |
| 122 | + return [{"name": m.cost_model.name, "uuid": m.cost_model.uuid} for m in cost_models_map] |
0 commit comments