Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
nose==1.3.7
wsgiref==0.1.2
sphinx
sphinx-autobuild
sphinx-rtd-theme
-e git+https://github.com/caxiam/model-api.git#egg=Package
sphinx-rtd-theme
6 changes: 3 additions & 3 deletions rest_orm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
import errors
import fields
import models
from rest_orm import errors
from rest_orm import fields
from rest_orm import models
77 changes: 34 additions & 43 deletions rest_orm/fields.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,26 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from decimal import Decimal
from decimal import Decimal as PyDecimal

from rest_orm.utils import get_class


class AdaptedField(object):
class Field(object):
"""Flat representaion of remote endpoint's field.

`AdaptedField` and its child classes are self-destructive. Once
`Field` and its child classes are self-destructive. Once
deserialization is complete, the instance is replaced by the typed
value retrieved.
"""

def __init__(self, path, missing=None, nullable=True, required=False,
validate=None):
def __init__(self, path, missing=None):
"""Key extraction strategy and settings.

:param path: A formattable string path.
:param missing: The default deserialization value.
:param nullable: If `False`, disallow `None` type values.
:param required: If `True`, raise an error if the key is missing.
:param validate: A callable object.
"""
self.path = path
self.missing = missing
self.nullable = nullable
self.required = required
self.validate = validate

def deserialize(self, data):
"""Extract a value from the provided data object.
Expand All @@ -38,28 +31,16 @@ def deserialize(self, data):
return self._deserialize(data)

try:
raw_value = self.map_from_string(self.path, data)
value = self.map_from_string(self.path, data)
except (KeyError, IndexError):
if self.required:
raise KeyError('{} not found.'.format(self.path))
value = self.missing
else:
if raw_value is None and self.nullable:
value = None
else:
value = self._deserialize(raw_value)

self._validate(value)
return value
if value is None:
return value
return self._deserialize(value)

def _deserialize(self, value):
return value

def _validate(self, value):
if self.validate is not None:
self.validate(value)
return None

def map_from_string(self, path, data):
"""Return nested value from the string path taken.

Expand All @@ -77,50 +58,60 @@ def extract_by_type(path):
return data


class AdaptedBoolean(AdaptedField):
class Boolean(Field):
"""Parse an adapted field into the boolean type."""

def _deserialize(self, value):
return bool(value)


class AdaptedDate(AdaptedField):
class Date(Field):
"""Parse an adapted field into the datetime type."""

def __init__(self, *args, **kwargs):
self.date_format = kwargs.pop('date_format', '%Y-%m-%d')
super(AdaptedDate, self).__init__(*args, **kwargs)
super(Date, self).__init__(*args, **kwargs)

def _deserialize(self, value):
return datetime.strptime(value, self.date_format)


class AdaptedDecimal(AdaptedField):
class Dump(Field):
"""Return a pre-determined value."""

def __init__(self, value):
self.value = value

def deserialize(self, data):
return self.value


class Decimal(Field):
"""Parse an adapted field into the decimal type."""

def _deserialize(self, value):
return Decimal(value)
return PyDecimal(value)


class AdaptedInteger(AdaptedField):
class Integer(Field):
"""Parse an adapted field into the integer type."""

def _deserialize(self, value):
return int(value)


class AdaptedFunction(AdaptedField):
class Function(Field):
"""Parse an adapted field into a specified function's output."""

def __init__(self, f, *args, **kwargs):
self.f = f
super(AdaptedFunction, self).__init__(*args, **kwargs)
super(Function, self).__init__(*args, **kwargs)

def _deserialize(self, value):
return self.f(value)


class AdaptedList(AdaptedField):
class List(Field):
"""Parse an adapted field into the list type."""

def _deserialize(self, value):
Expand All @@ -129,20 +120,20 @@ def _deserialize(self, value):
return value


class AdaptedNested(AdaptedField):
"""Parse an adatped field into the AdaptedModel type."""
class Nested(Field):
"""Parse an adatped field into the Model type."""

def __init__(self, model, *args, **kwargs):
"""Parse a list of nested objects into an AdaptedModel.
"""Parse a list of nested objects into an Model.

:param model: AdaptedModel name or reference.
:param model: Model name or reference.
"""
self.nested_model = model
super(AdaptedNested, self).__init__(*args, **kwargs)
super(Nested, self).__init__(*args, **kwargs)

@property
def model(self):
"""Return an AdaptedModel reference."""
"""Return an Model reference."""
if isinstance(self.nested_model, str):
return get_class(self.nested_model)
return self.nested_model
Expand All @@ -153,7 +144,7 @@ def _deserialize(self, value):
return self.model().load(value)


class AdaptedString(AdaptedField):
class String(Field):
"""Parse an adapted field into the string type."""

def _deserialize(self, value):
Expand Down
43 changes: 22 additions & 21 deletions rest_orm/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
from rest_orm.fields import AdaptedField
from copy import copy

from rest_orm.fields import Field
from rest_orm.utils import ModelRegistry

import json
Expand All @@ -11,34 +13,33 @@ class BaseModel(object):
__metaclass__ = ModelRegistry


class AdaptedModel(BaseModel):
class Model(BaseModel):
"""A flat representation of a single remote endpoint."""

def loads(self, data):
"""Load a JSON string to a flattened dictionary."""
return self.load(json.loads(data))

def load(self, data):
"""Flatten a nested dictionary."""
response = {}
for field_name in dir(self):
field = getattr(self, field_name)
if not isinstance(field, Field):
continue
response[field_name] = copy(field.deserialize(data))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@morgan This is the relevant part that was causing the issue. Using copy prevents an accidental reference to the field class from being passed and updated.


response = self.post_load(response)
return response

def connect(self, *args, **kwargs):
"""Make a request to a remote endpoint and load its JSON response."""
response = self.make_request(*args, **kwargs)
return self.loads(response)

def loads(self, response):
"""Marshal a JSON response object into the model."""
return self.load(json.loads(response))

def load(self, response):
"""Marshal a python dictionary object into the model."""
self._do_load(response)
self.post_load()
return self

def post_load(self):
def post_load(self, data):
"""Perform any model level actions after load."""
pass

def _do_load(self, data):
for field_name in dir(self):
field = getattr(self, field_name)
if not isinstance(field, AdaptedField):
continue
setattr(self, field_name, field.deserialize(data))
return data

def make_request(self):
"""Return the response data of a remote endpoint."""
Expand Down
Loading